diff --git a/homeassistant/components/duke_energy/coordinator.py b/homeassistant/components/duke_energy/coordinator.py index a70c94e6fee..eac01f2ad39 100644 --- a/homeassistant/components/duke_energy/coordinator.py +++ b/homeassistant/components/duke_energy/coordinator.py @@ -24,6 +24,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.util import dt as dt_util +from homeassistant.util.unit_conversion import EnergyConverter from .const import DOMAIN @@ -146,6 +147,7 @@ class DukeEnergyCoordinator(DataUpdateCoordinator[None]): name=f"{name_prefix} Consumption", source=DOMAIN, statistic_id=consumption_statistic_id, + unit_class=EnergyConverter.UNIT_CLASS, unit_of_measurement=UnitOfEnergy.KILO_WATT_HOUR if meter["serviceType"] == "ELECTRIC" else UnitOfVolume.CENTUM_CUBIC_FEET, diff --git a/homeassistant/components/elvia/importer.py b/homeassistant/components/elvia/importer.py index caca787237c..40795458f66 100644 --- a/homeassistant/components/elvia/importer.py +++ b/homeassistant/components/elvia/importer.py @@ -20,6 +20,7 @@ from homeassistant.components.recorder.statistics import ( from homeassistant.components.recorder.util import get_instance from homeassistant.const import UnitOfEnergy from homeassistant.util import dt as dt_util +from homeassistant.util.unit_conversion import EnergyConverter from .const import DOMAIN, LOGGER @@ -153,6 +154,7 @@ class ElviaImporter: name=f"{self.metering_point_id} Consumption", source=DOMAIN, statistic_id=statistic_id, + unit_class=EnergyConverter.UNIT_CLASS, unit_of_measurement=UnitOfEnergy.KILO_WATT_HOUR, ), statistics=statistics, diff --git a/homeassistant/components/ista_ecotrend/sensor.py b/homeassistant/components/ista_ecotrend/sensor.py index 0a8ed6e9ddb..95096375530 100644 --- a/homeassistant/components/ista_ecotrend/sensor.py +++ b/homeassistant/components/ista_ecotrend/sensor.py @@ -34,6 +34,7 @@ from homeassistant.helpers.device_registry import ( from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.typing import StateType from homeassistant.helpers.update_coordinator import CoordinatorEntity +from homeassistant.util.unit_conversion import EnergyConverter, VolumeConverter from .const import DOMAIN from .coordinator import IstaConfigEntry, IstaCoordinator @@ -49,6 +50,7 @@ class IstaSensorEntityDescription(SensorEntityDescription): """Ista EcoTrend Sensor Description.""" consumption_type: IstaConsumptionType + unit_class: str | None = None value_type: IstaValueType | None = None @@ -84,6 +86,7 @@ SENSOR_DESCRIPTIONS: tuple[IstaSensorEntityDescription, ...] = ( suggested_display_precision=1, consumption_type=IstaConsumptionType.HEATING, value_type=IstaValueType.ENERGY, + unit_class=EnergyConverter.UNIT_CLASS, ), IstaSensorEntityDescription( key=IstaSensorEntity.HEATING_COST, @@ -104,6 +107,7 @@ SENSOR_DESCRIPTIONS: tuple[IstaSensorEntityDescription, ...] = ( state_class=SensorStateClass.TOTAL, suggested_display_precision=1, consumption_type=IstaConsumptionType.HOT_WATER, + unit_class=VolumeConverter.UNIT_CLASS, ), IstaSensorEntityDescription( key=IstaSensorEntity.HOT_WATER_ENERGY, @@ -114,6 +118,7 @@ SENSOR_DESCRIPTIONS: tuple[IstaSensorEntityDescription, ...] = ( suggested_display_precision=1, consumption_type=IstaConsumptionType.HOT_WATER, value_type=IstaValueType.ENERGY, + unit_class=EnergyConverter.UNIT_CLASS, ), IstaSensorEntityDescription( key=IstaSensorEntity.HOT_WATER_COST, @@ -135,6 +140,7 @@ SENSOR_DESCRIPTIONS: tuple[IstaSensorEntityDescription, ...] = ( suggested_display_precision=1, entity_registry_enabled_default=False, consumption_type=IstaConsumptionType.WATER, + unit_class=VolumeConverter.UNIT_CLASS, ), IstaSensorEntityDescription( key=IstaSensorEntity.WATER_COST, @@ -276,6 +282,7 @@ class IstaSensor(CoordinatorEntity[IstaCoordinator], SensorEntity): "name": f"{self.device_entry.name} {self.name}", "source": DOMAIN, "statistic_id": statistic_id, + "unit_class": self.entity_description.unit_class, "unit_of_measurement": self.entity_description.native_unit_of_measurement, } if statistics: diff --git a/homeassistant/components/kitchen_sink/__init__.py b/homeassistant/components/kitchen_sink/__init__.py index e6a2e98bcaf..cb782b258d9 100644 --- a/homeassistant/components/kitchen_sink/__init__.py +++ b/homeassistant/components/kitchen_sink/__init__.py @@ -36,6 +36,11 @@ from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType from homeassistant.util import dt as dt_util +from homeassistant.util.unit_conversion import ( + EnergyConverter, + TemperatureConverter, + VolumeConverter, +) from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN @@ -254,6 +259,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": DOMAIN, "name": "Outdoor temperature", "statistic_id": f"{DOMAIN}:temperature_outdoor", + "unit_class": TemperatureConverter.UNIT_CLASS, "unit_of_measurement": UnitOfTemperature.CELSIUS, "mean_type": StatisticMeanType.ARITHMETIC, "has_sum": False, @@ -267,6 +273,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": DOMAIN, "name": "Energy consumption 1", "statistic_id": f"{DOMAIN}:energy_consumption_kwh", + "unit_class": EnergyConverter.UNIT_CLASS, "unit_of_measurement": UnitOfEnergy.KILO_WATT_HOUR, "mean_type": StatisticMeanType.NONE, "has_sum": True, @@ -279,6 +286,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": DOMAIN, "name": "Energy consumption 2", "statistic_id": f"{DOMAIN}:energy_consumption_mwh", + "unit_class": EnergyConverter.UNIT_CLASS, "unit_of_measurement": UnitOfEnergy.MEGA_WATT_HOUR, "mean_type": StatisticMeanType.NONE, "has_sum": True, @@ -293,6 +301,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": DOMAIN, "name": "Gas consumption 1", "statistic_id": f"{DOMAIN}:gas_consumption_m3", + "unit_class": VolumeConverter.UNIT_CLASS, "unit_of_measurement": UnitOfVolume.CUBIC_METERS, "mean_type": StatisticMeanType.NONE, "has_sum": True, @@ -307,6 +316,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": DOMAIN, "name": "Gas consumption 2", "statistic_id": f"{DOMAIN}:gas_consumption_ft3", + "unit_class": VolumeConverter.UNIT_CLASS, "unit_of_measurement": UnitOfVolume.CUBIC_FEET, "mean_type": StatisticMeanType.NONE, "has_sum": True, @@ -319,6 +329,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": RECORDER_DOMAIN, "name": None, "statistic_id": "sensor.statistics_issues_issue_1", + "unit_class": VolumeConverter.UNIT_CLASS, "unit_of_measurement": UnitOfVolume.CUBIC_METERS, "mean_type": StatisticMeanType.ARITHMETIC, "has_sum": False, @@ -331,6 +342,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": RECORDER_DOMAIN, "name": None, "statistic_id": "sensor.statistics_issues_issue_2", + "unit_class": None, "unit_of_measurement": "cats", "mean_type": StatisticMeanType.ARITHMETIC, "has_sum": False, @@ -343,6 +355,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": RECORDER_DOMAIN, "name": None, "statistic_id": "sensor.statistics_issues_issue_3", + "unit_class": VolumeConverter.UNIT_CLASS, "unit_of_measurement": UnitOfVolume.CUBIC_METERS, "mean_type": StatisticMeanType.ARITHMETIC, "has_sum": False, @@ -355,6 +368,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: "source": RECORDER_DOMAIN, "name": None, "statistic_id": "sensor.statistics_issues_issue_4", + "unit_class": VolumeConverter.UNIT_CLASS, "unit_of_measurement": UnitOfVolume.CUBIC_METERS, "mean_type": StatisticMeanType.ARITHMETIC, "has_sum": False, @@ -375,6 +389,7 @@ async def _insert_wrong_wind_direction_statistics(hass: HomeAssistant) -> None: "source": RECORDER_DOMAIN, "name": None, "statistic_id": "sensor.statistics_issues_issue_5", + "unit_class": None, "unit_of_measurement": DEGREE, "mean_type": StatisticMeanType.ARITHMETIC, "has_sum": False, diff --git a/homeassistant/components/mill/coordinator.py b/homeassistant/components/mill/coordinator.py index 1991cad213c..222e77efdf7 100644 --- a/homeassistant/components/mill/coordinator.py +++ b/homeassistant/components/mill/coordinator.py @@ -25,6 +25,7 @@ from homeassistant.const import UnitOfEnergy from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.util import dt as dt_util, slugify +from homeassistant.util.unit_conversion import EnergyConverter from .const import DOMAIN @@ -156,6 +157,7 @@ class MillHistoricDataUpdateCoordinator(DataUpdateCoordinator): name=f"{heater.name}", source=DOMAIN, statistic_id=statistic_id, + unit_class=EnergyConverter.UNIT_CLASS, unit_of_measurement=UnitOfEnergy.KILO_WATT_HOUR, ) async_add_external_statistics(self.hass, metadata, statistics) diff --git a/homeassistant/components/opower/coordinator.py b/homeassistant/components/opower/coordinator.py index e6fbbee0bb6..beac8971cd2 100644 --- a/homeassistant/components/opower/coordinator.py +++ b/homeassistant/components/opower/coordinator.py @@ -35,6 +35,7 @@ from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.aiohttp_client import async_create_clientsession from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.util import dt as dt_util +from homeassistant.util.unit_conversion import EnergyConverter, VolumeConverter from .const import CONF_LOGIN_DATA, CONF_TOTP_SECRET, CONF_UTILITY, DOMAIN @@ -149,6 +150,7 @@ class OpowerCoordinator(DataUpdateCoordinator[dict[str, Forecast]]): name=f"{name_prefix} cost", source=DOMAIN, statistic_id=cost_statistic_id, + unit_class=None, unit_of_measurement=None, ) compensation_metadata = StatisticMetaData( @@ -157,8 +159,14 @@ class OpowerCoordinator(DataUpdateCoordinator[dict[str, Forecast]]): name=f"{name_prefix} compensation", source=DOMAIN, statistic_id=compensation_statistic_id, + unit_class=None, unit_of_measurement=None, ) + consumption_unit_class = ( + EnergyConverter.UNIT_CLASS + if account.meter_type == MeterType.ELEC + else VolumeConverter.UNIT_CLASS + ) consumption_unit = ( UnitOfEnergy.KILO_WATT_HOUR if account.meter_type == MeterType.ELEC @@ -170,6 +178,7 @@ class OpowerCoordinator(DataUpdateCoordinator[dict[str, Forecast]]): name=f"{name_prefix} consumption", source=DOMAIN, statistic_id=consumption_statistic_id, + unit_class=consumption_unit_class, unit_of_measurement=consumption_unit, ) return_metadata = StatisticMetaData( @@ -178,6 +187,7 @@ class OpowerCoordinator(DataUpdateCoordinator[dict[str, Forecast]]): name=f"{name_prefix} return", source=DOMAIN, statistic_id=return_statistic_id, + unit_class=consumption_unit_class, unit_of_measurement=consumption_unit, ) diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index b1563d85d56..056f359bcec 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -54,6 +54,7 @@ CONTEXT_ID_AS_BINARY_SCHEMA_VERSION = 36 EVENT_TYPE_IDS_SCHEMA_VERSION = 37 STATES_META_SCHEMA_VERSION = 38 CIRCULAR_MEAN_SCHEMA_VERSION = 49 +UNIT_CLASS_SCHEMA_VERSION = 51 LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION = 28 LEGACY_STATES_EVENT_FOREIGN_KEYS_FIXED_SCHEMA_VERSION = 43 diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index d662416012f..a0f5c779c0e 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -574,13 +574,18 @@ class Recorder(threading.Thread): statistic_id: str, *, new_statistic_id: str | UndefinedType = UNDEFINED, + new_unit_class: str | None | UndefinedType = UNDEFINED, new_unit_of_measurement: str | None | UndefinedType = UNDEFINED, on_done: Callable[[], None] | None = None, ) -> None: """Update statistics metadata for a statistic_id.""" self.queue_task( UpdateStatisticsMetadataTask( - on_done, statistic_id, new_statistic_id, new_unit_of_measurement + on_done, + statistic_id, + new_statistic_id, + new_unit_class, + new_unit_of_measurement, ) ) diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index a0e82de9fe0..6e3200a5fcd 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -71,7 +71,7 @@ class LegacyBase(DeclarativeBase): """Base class for tables, used for schema migration.""" -SCHEMA_VERSION = 50 +SCHEMA_VERSION = 51 _LOGGER = logging.getLogger(__name__) @@ -756,6 +756,7 @@ class _StatisticsMeta: ) source: Mapped[str | None] = mapped_column(String(32)) unit_of_measurement: Mapped[str | None] = mapped_column(String(255)) + unit_class: Mapped[str | None] = mapped_column(String(255)) has_mean: Mapped[bool | None] = mapped_column(Boolean) has_sum: Mapped[bool | None] = mapped_column(Boolean) name: Mapped[str | None] = mapped_column(String(255)) diff --git a/homeassistant/components/recorder/entity_registry.py b/homeassistant/components/recorder/entity_registry.py index 904582b75f0..75b0df2bf44 100644 --- a/homeassistant/components/recorder/entity_registry.py +++ b/homeassistant/components/recorder/entity_registry.py @@ -9,6 +9,7 @@ from homeassistant.helpers import entity_registry as er from homeassistant.helpers.event import async_has_entity_registry_updated_listeners from .core import Recorder +from .statistics import async_update_statistics_metadata from .util import filter_unique_constraint_integrity_error, get_instance, session_scope _LOGGER = logging.getLogger(__name__) @@ -27,8 +28,8 @@ def async_setup(hass: HomeAssistant) -> None: assert event.data["action"] == "update" and "old_entity_id" in event.data old_entity_id = event.data["old_entity_id"] new_entity_id = event.data["entity_id"] - instance.async_update_statistics_metadata( - old_entity_id, new_statistic_id=new_entity_id + async_update_statistics_metadata( + hass, old_entity_id, new_statistic_id=new_entity_id ) instance.async_update_states_metadata( old_entity_id, new_entity_id=new_entity_id diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 708be5eab20..ed9d761e025 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -103,7 +103,11 @@ from .queries import ( migrate_single_short_term_statistics_row_to_timestamp, migrate_single_statistics_row_to_timestamp, ) -from .statistics import cleanup_statistics_timestamp_migration, get_start_time +from .statistics import ( + _PRIMARY_UNIT_CONVERTERS, + cleanup_statistics_timestamp_migration, + get_start_time, +) from .tasks import RecorderTask from .util import ( database_job_retry_wrapper, @@ -2037,6 +2041,21 @@ class _SchemaVersion50Migrator(_SchemaVersionMigrator, target_version=50): connection.execute(text("UPDATE statistics_meta SET has_mean=NULL")) +class _SchemaVersion51Migrator(_SchemaVersionMigrator, target_version=51): + def _apply_update(self) -> None: + """Version specific update method.""" + # Add unit class column to StatisticsMeta + _add_columns(self.session_maker, "statistics_meta", ["unit_class VARCHAR(255)"]) + with session_scope(session=self.session_maker()) as session: + connection = session.connection() + for conv in _PRIMARY_UNIT_CONVERTERS: + connection.execute( + update(StatisticsMeta) + .where(StatisticsMeta.unit_of_measurement.in_(conv.VALID_UNITS)) + .values(unit_class=conv.UNIT_CLASS) + ) + + def _migrate_statistics_columns_to_timestamp_removing_duplicates( hass: HomeAssistant, instance: Recorder, diff --git a/homeassistant/components/recorder/models/statistics.py b/homeassistant/components/recorder/models/statistics.py index be216923892..c4d6ccded31 100644 --- a/homeassistant/components/recorder/models/statistics.py +++ b/homeassistant/components/recorder/models/statistics.py @@ -70,6 +70,8 @@ class StatisticMetaData(TypedDict): name: str | None source: str statistic_id: str + unit_class: str | None + """Specifies the unit conversion class to use, if applicable.""" unit_of_measurement: str | None diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 2321da45bb9..c10808d5047 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -35,6 +35,7 @@ import voluptuous as vol from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT from homeassistant.core import HomeAssistant, callback, valid_entity_id from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.frame import report_usage from homeassistant.helpers.recorder import DATA_RECORDER from homeassistant.helpers.singleton import singleton from homeassistant.helpers.typing import UNDEFINED, UndefinedType @@ -193,43 +194,48 @@ QUERY_STATISTICS_SUMMARY_SUM = ( .label("rownum"), ) +_PRIMARY_UNIT_CONVERTERS: list[type[BaseUnitConverter]] = [ + ApparentPowerConverter, + AreaConverter, + BloodGlucoseConcentrationConverter, + ConductivityConverter, + DataRateConverter, + DistanceConverter, + DurationConverter, + ElectricCurrentConverter, + ElectricPotentialConverter, + EnergyConverter, + EnergyDistanceConverter, + InformationConverter, + MassConverter, + MassVolumeConcentrationConverter, + PowerConverter, + PressureConverter, + ReactiveEnergyConverter, + ReactivePowerConverter, + SpeedConverter, + TemperatureConverter, + UnitlessRatioConverter, + VolumeConverter, + VolumeFlowRateConverter, +] + +_SECONDARY_UNIT_CONVERTERS: list[type[BaseUnitConverter]] = [] STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = { - **dict.fromkeys(ApparentPowerConverter.VALID_UNITS, ApparentPowerConverter), - **dict.fromkeys(AreaConverter.VALID_UNITS, AreaConverter), - **dict.fromkeys( - BloodGlucoseConcentrationConverter.VALID_UNITS, - BloodGlucoseConcentrationConverter, - ), - **dict.fromkeys( - MassVolumeConcentrationConverter.VALID_UNITS, MassVolumeConcentrationConverter - ), - **dict.fromkeys(ConductivityConverter.VALID_UNITS, ConductivityConverter), - **dict.fromkeys(DataRateConverter.VALID_UNITS, DataRateConverter), - **dict.fromkeys(DistanceConverter.VALID_UNITS, DistanceConverter), - **dict.fromkeys(DurationConverter.VALID_UNITS, DurationConverter), - **dict.fromkeys(ElectricCurrentConverter.VALID_UNITS, ElectricCurrentConverter), - **dict.fromkeys(ElectricPotentialConverter.VALID_UNITS, ElectricPotentialConverter), - **dict.fromkeys(EnergyConverter.VALID_UNITS, EnergyConverter), - **dict.fromkeys(EnergyDistanceConverter.VALID_UNITS, EnergyDistanceConverter), - **dict.fromkeys(InformationConverter.VALID_UNITS, InformationConverter), - **dict.fromkeys(MassConverter.VALID_UNITS, MassConverter), - **dict.fromkeys(PowerConverter.VALID_UNITS, PowerConverter), - **dict.fromkeys(PressureConverter.VALID_UNITS, PressureConverter), - **dict.fromkeys(ReactiveEnergyConverter.VALID_UNITS, ReactiveEnergyConverter), - **dict.fromkeys(ReactivePowerConverter.VALID_UNITS, ReactivePowerConverter), - **dict.fromkeys(SpeedConverter.VALID_UNITS, SpeedConverter), - **dict.fromkeys(TemperatureConverter.VALID_UNITS, TemperatureConverter), - **dict.fromkeys(UnitlessRatioConverter.VALID_UNITS, UnitlessRatioConverter), - **dict.fromkeys(VolumeConverter.VALID_UNITS, VolumeConverter), - **dict.fromkeys(VolumeFlowRateConverter.VALID_UNITS, VolumeFlowRateConverter), + unit: conv for conv in _PRIMARY_UNIT_CONVERTERS for unit in conv.VALID_UNITS } +"""Map of units to unit converter. +This map includes units which can be converted without knowing the unit class. +""" -UNIT_CLASSES = { - unit: converter.UNIT_CLASS - for unit, converter in STATISTIC_UNIT_TO_UNIT_CONVERTER.items() +UNIT_CLASS_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = { + conv.UNIT_CLASS: conv + for conv in chain(_PRIMARY_UNIT_CONVERTERS, _SECONDARY_UNIT_CONVERTERS) } +"""Map of unit class to converter.""" + DATA_SHORT_TERM_STATISTICS_RUN_CACHE = "recorder_short_term_statistics_run_cache" @@ -315,14 +321,32 @@ class StatisticsRow(BaseStatisticsRow, total=False): change: float | None +def _get_unit_converter( + unit_class: str | None, from_unit: str | None +) -> type[BaseUnitConverter] | None: + """Return the unit converter for the given unit class and unit. + + The unit converter is determined from the unit class and unit if the unit class + and unit match, otherwise from the unit. + """ + if ( + conv := UNIT_CLASS_TO_UNIT_CONVERTER.get(unit_class) + ) is not None and from_unit in conv.VALID_UNITS: + return conv + if (conv := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(from_unit)) is not None: + return conv + return None + + def get_display_unit( hass: HomeAssistant, statistic_id: str, + unit_class: str | None, statistic_unit: str | None, ) -> str | None: """Return the unit which the statistic will be displayed in.""" - if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None: + if (converter := _get_unit_converter(unit_class, statistic_unit)) is None: return statistic_unit state_unit: str | None = statistic_unit @@ -337,13 +361,14 @@ def get_display_unit( def _get_statistic_to_display_unit_converter( + unit_class: str | None, statistic_unit: str | None, state_unit: str | None, requested_units: dict[str, str] | None, allow_none: bool = True, ) -> Callable[[float | None], float | None] | Callable[[float], float] | None: """Prepare a converter from the statistics unit to display unit.""" - if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None: + if (converter := _get_unit_converter(unit_class, statistic_unit)) is None: return None display_unit: str | None @@ -367,24 +392,25 @@ def _get_statistic_to_display_unit_converter( return converter.converter_factory(from_unit=statistic_unit, to_unit=display_unit) -def _get_display_to_statistic_unit_converter( +def _get_display_to_statistic_unit_converter_func( + unit_class: str | None, display_unit: str | None, statistic_unit: str | None, ) -> Callable[[float], float] | None: """Prepare a converter from the display unit to the statistics unit.""" if ( display_unit == statistic_unit - or (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None + or (converter := _get_unit_converter(unit_class, statistic_unit)) is None ): return None return converter.converter_factory(from_unit=display_unit, to_unit=statistic_unit) -def _get_unit_converter( - from_unit: str, to_unit: str +def _get_unit_converter_func( + unit_class: str | None, from_unit: str, to_unit: str ) -> Callable[[float | None], float | None] | None: """Prepare a converter from a unit to another unit.""" - for conv in STATISTIC_UNIT_TO_UNIT_CONVERTER.values(): + if (conv := _get_unit_converter(unit_class, from_unit)) is not None: if from_unit in conv.VALID_UNITS and to_unit in conv.VALID_UNITS: if from_unit == to_unit: return None @@ -394,9 +420,11 @@ def _get_unit_converter( raise HomeAssistantError -def can_convert_units(from_unit: str | None, to_unit: str | None) -> bool: +def can_convert_units( + unit_class: str | None, from_unit: str | None, to_unit: str | None +) -> bool: """Return True if it's possible to convert from from_unit to to_unit.""" - for converter in STATISTIC_UNIT_TO_UNIT_CONVERTER.values(): + if (converter := _get_unit_converter(unit_class, from_unit)) is not None: if from_unit in converter.VALID_UNITS and to_unit in converter.VALID_UNITS: return True return False @@ -863,18 +891,71 @@ def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: instance.statistics_meta_manager.delete(session, statistic_ids) +@callback +def async_update_statistics_metadata( + hass: HomeAssistant, + statistic_id: str, + *, + new_statistic_id: str | UndefinedType = UNDEFINED, + new_unit_class: str | None | UndefinedType = UNDEFINED, + new_unit_of_measurement: str | None | UndefinedType = UNDEFINED, + on_done: Callable[[], None] | None = None, + _called_from_ws_api: bool = False, +) -> None: + """Update statistics metadata for a statistic_id.""" + if new_unit_of_measurement is not UNDEFINED and new_unit_class is UNDEFINED: + if not _called_from_ws_api: + report_usage( + ( + "doesn't specify unit_class when calling " + "async_update_statistics_metadata" + ), + breaks_in_ha_version="2026.11", + exclude_integrations={DOMAIN}, + ) + + unit = new_unit_of_measurement + if unit in STATISTIC_UNIT_TO_UNIT_CONVERTER: + new_unit_class = STATISTIC_UNIT_TO_UNIT_CONVERTER[unit].UNIT_CLASS + else: + new_unit_class = None + + if TYPE_CHECKING: + # After the above check, new_unit_class is guaranteed to not be UNDEFINED + assert new_unit_class is not UNDEFINED + + if new_unit_of_measurement is not UNDEFINED and new_unit_class is not None: + if (converter := UNIT_CLASS_TO_UNIT_CONVERTER.get(new_unit_class)) is None: + raise HomeAssistantError(f"Unsupported unit_class: '{new_unit_class}'") + + if new_unit_of_measurement not in converter.VALID_UNITS: + raise HomeAssistantError( + f"Unsupported unit_of_measurement '{new_unit_of_measurement}' " + f"for unit_class '{new_unit_class}'" + ) + + get_instance(hass).async_update_statistics_metadata( + statistic_id, + new_statistic_id=new_statistic_id, + new_unit_class=new_unit_class, + new_unit_of_measurement=new_unit_of_measurement, + on_done=on_done, + ) + + def update_statistics_metadata( instance: Recorder, statistic_id: str, new_statistic_id: str | None | UndefinedType, + new_unit_class: str | None | UndefinedType, new_unit_of_measurement: str | None | UndefinedType, ) -> None: """Update statistics metadata for a statistic_id.""" statistics_meta_manager = instance.statistics_meta_manager - if new_unit_of_measurement is not UNDEFINED: + if new_unit_class is not UNDEFINED and new_unit_of_measurement is not UNDEFINED: with session_scope(session=instance.get_session()) as session: statistics_meta_manager.update_unit_of_measurement( - session, statistic_id, new_unit_of_measurement + session, statistic_id, new_unit_class, new_unit_of_measurement ) if new_statistic_id is not UNDEFINED and new_statistic_id is not None: with session_scope( @@ -926,13 +1007,16 @@ def _statistic_by_id_from_metadata( return { meta["statistic_id"]: { "display_unit_of_measurement": get_display_unit( - hass, meta["statistic_id"], meta["unit_of_measurement"] + hass, + meta["statistic_id"], + meta["unit_class"], + meta["unit_of_measurement"], ), "mean_type": meta["mean_type"], "has_sum": meta["has_sum"], "name": meta["name"], "source": meta["source"], - "unit_class": UNIT_CLASSES.get(meta["unit_of_measurement"]), + "unit_class": meta["unit_class"], "unit_of_measurement": meta["unit_of_measurement"], } for _, meta in metadata.values() @@ -1008,7 +1092,7 @@ def list_statistic_ids( "has_sum": meta["has_sum"], "name": meta["name"], "source": meta["source"], - "unit_class": UNIT_CLASSES.get(meta["unit_of_measurement"]), + "unit_class": meta["unit_class"], "unit_of_measurement": meta["unit_of_measurement"], } @@ -1744,10 +1828,13 @@ def statistic_during_period( else: result["change"] = None + unit_class = metadata[1]["unit_class"] state_unit = unit = metadata[1]["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) - convert = _get_statistic_to_display_unit_converter(unit, state_unit, units) + convert = _get_statistic_to_display_unit_converter( + unit_class, unit, state_unit, units + ) if not convert: return result @@ -1830,10 +1917,13 @@ def _augment_result_with_change( metadata_by_id = _metadata[row.metadata_id] statistic_id = metadata_by_id["statistic_id"] + unit_class = metadata_by_id["unit_class"] state_unit = unit = metadata_by_id["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) - convert = _get_statistic_to_display_unit_converter(unit, state_unit, units) + convert = _get_statistic_to_display_unit_converter( + unit_class, unit, state_unit, units + ) if convert is not None: prev_sums[statistic_id] = convert(row.sum) @@ -2426,11 +2516,12 @@ def _sorted_statistics_to_dict( metadata_by_id = metadata[meta_id] statistic_id = metadata_by_id["statistic_id"] if convert_units: + unit_class = metadata_by_id["unit_class"] state_unit = unit = metadata_by_id["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) convert = _get_statistic_to_display_unit_converter( - unit, state_unit, units, allow_none=False + unit_class, unit, state_unit, units, allow_none=False ) else: convert = None @@ -2501,6 +2592,27 @@ def _async_import_statistics( statistics: Iterable[StatisticData], ) -> None: """Validate timestamps and insert an import_statistics job in the queue.""" + # If unit class is not set, we try to set it based on the unit of measurement + # Note: This can't happen from the type checker's perspective, but we need + # to guard against custom integrations that have not been updated to set + # the unit_class. + if "unit_class" not in metadata: + unit = metadata["unit_of_measurement"] # type: ignore[unreachable] + if unit in STATISTIC_UNIT_TO_UNIT_CONVERTER: + metadata["unit_class"] = STATISTIC_UNIT_TO_UNIT_CONVERTER[unit].UNIT_CLASS + else: + metadata["unit_class"] = None + + if (unit_class := metadata["unit_class"]) is not None: + if (converter := UNIT_CLASS_TO_UNIT_CONVERTER.get(unit_class)) is None: + raise HomeAssistantError(f"Unsupported unit_class: '{unit_class}'") + + if metadata["unit_of_measurement"] not in converter.VALID_UNITS: + raise HomeAssistantError( + f"Unsupported unit_of_measurement '{metadata['unit_of_measurement']}' " + f"for unit_class '{unit_class}'" + ) + for statistic in statistics: start = statistic["start"] if start.tzinfo is None or start.tzinfo.utcoffset(start) is None: @@ -2532,6 +2644,8 @@ def async_import_statistics( hass: HomeAssistant, metadata: StatisticMetaData, statistics: Iterable[StatisticData], + *, + _called_from_ws_api: bool = False, ) -> None: """Import hourly statistics from an internal source. @@ -2544,6 +2658,13 @@ def async_import_statistics( if not metadata["source"] or metadata["source"] != DOMAIN: raise HomeAssistantError("Invalid source") + if "unit_class" not in metadata and not _called_from_ws_api: # type: ignore[unreachable] + report_usage( # type: ignore[unreachable] + "doesn't specify unit_class when calling async_import_statistics", + breaks_in_ha_version="2026.11", + exclude_integrations={DOMAIN}, + ) + _async_import_statistics(hass, metadata, statistics) @@ -2552,6 +2673,8 @@ def async_add_external_statistics( hass: HomeAssistant, metadata: StatisticMetaData, statistics: Iterable[StatisticData], + *, + _called_from_ws_api: bool = False, ) -> None: """Add hourly statistics from an external source. @@ -2566,6 +2689,13 @@ def async_add_external_statistics( if not metadata["source"] or metadata["source"] != domain: raise HomeAssistantError("Invalid source") + if "unit_class" not in metadata and not _called_from_ws_api: # type: ignore[unreachable] + report_usage( # type: ignore[unreachable] + "doesn't specify unit_class when calling async_add_external_statistics", + breaks_in_ha_version="2026.11", + exclude_integrations={DOMAIN}, + ) + _async_import_statistics(hass, metadata, statistics) @@ -2699,9 +2829,10 @@ def adjust_statistics( if statistic_id not in metadata: return True + unit_class = metadata[statistic_id][1]["unit_class"] statistic_unit = metadata[statistic_id][1]["unit_of_measurement"] - if convert := _get_display_to_statistic_unit_converter( - adjustment_unit, statistic_unit + if convert := _get_display_to_statistic_unit_converter_func( + unit_class, adjustment_unit, statistic_unit ): sum_adjustment = convert(sum_adjustment) @@ -2769,8 +2900,9 @@ def change_statistics_unit( return metadata_id = metadata[0] + unit_class = metadata[1]["unit_class"] - if not (convert := _get_unit_converter(old_unit, new_unit)): + if not (convert := _get_unit_converter_func(unit_class, old_unit, new_unit)): _LOGGER.warning( "Statistics unit of measurement for %s is already %s", statistic_id, @@ -2786,12 +2918,14 @@ def change_statistics_unit( _change_statistics_unit_for_table(session, table, metadata_id, convert) statistics_meta_manager.update_unit_of_measurement( - session, statistic_id, new_unit + session, + statistic_id, + unit_class, + new_unit, ) -@callback -def async_change_statistics_unit( +async def async_change_statistics_unit( hass: HomeAssistant, statistic_id: str, *, @@ -2799,7 +2933,17 @@ def async_change_statistics_unit( old_unit_of_measurement: str, ) -> None: """Change statistics unit for a statistic_id.""" - if not can_convert_units(old_unit_of_measurement, new_unit_of_measurement): + metadatas = await get_instance(hass).async_add_executor_job( + partial(get_metadata, hass, statistic_ids={statistic_id}) + ) + if statistic_id not in metadatas: + raise HomeAssistantError(f"No metadata found for {statistic_id}") + + metadata = metadatas[statistic_id][1] + + if not can_convert_units( + metadata["unit_class"], old_unit_of_measurement, new_unit_of_measurement + ): raise HomeAssistantError( f"Can't convert {old_unit_of_measurement} to {new_unit_of_measurement}" ) diff --git a/homeassistant/components/recorder/table_managers/statistics_meta.py b/homeassistant/components/recorder/table_managers/statistics_meta.py index 634e9565c12..0553f5e5327 100644 --- a/homeassistant/components/recorder/table_managers/statistics_meta.py +++ b/homeassistant/components/recorder/table_managers/statistics_meta.py @@ -13,9 +13,10 @@ from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import true from sqlalchemy.sql.lambdas import StatementLambdaElement -from ..const import CIRCULAR_MEAN_SCHEMA_VERSION +from ..const import CIRCULAR_MEAN_SCHEMA_VERSION, UNIT_CLASS_SCHEMA_VERSION from ..db_schema import StatisticsMeta from ..models import StatisticMeanType, StatisticMetaData +from ..statistics import STATISTIC_UNIT_TO_UNIT_CONVERTER from ..util import execute_stmt_lambda_element if TYPE_CHECKING: @@ -41,6 +42,7 @@ INDEX_UNIT_OF_MEASUREMENT: Final = 3 INDEX_HAS_SUM: Final = 4 INDEX_NAME: Final = 5 INDEX_MEAN_TYPE: Final = 6 +INDEX_UNIT_CLASS: Final = 7 def _generate_get_metadata_stmt( @@ -58,6 +60,8 @@ def _generate_get_metadata_stmt( columns.append(StatisticsMeta.mean_type) else: columns.append(StatisticsMeta.has_mean) + if schema_version >= UNIT_CLASS_SCHEMA_VERSION: + columns.append(StatisticsMeta.unit_class) stmt = lambda_stmt(lambda: select(*columns)) if statistic_ids: stmt += lambda q: q.where(StatisticsMeta.statistic_id.in_(statistic_ids)) @@ -140,6 +144,13 @@ class StatisticsMetaManager: if row[INDEX_MEAN_TYPE] else StatisticMeanType.NONE ) + if self.recorder.schema_version >= UNIT_CLASS_SCHEMA_VERSION: + unit_class = row[INDEX_UNIT_CLASS] + else: + conv = STATISTIC_UNIT_TO_UNIT_CONVERTER.get( + row[INDEX_UNIT_OF_MEASUREMENT] + ) + unit_class = conv.UNIT_CLASS if conv else None meta = { "has_mean": mean_type is StatisticMeanType.ARITHMETIC, "mean_type": mean_type, @@ -148,6 +159,7 @@ class StatisticsMetaManager: "source": row[INDEX_SOURCE], "statistic_id": statistic_id, "unit_of_measurement": row[INDEX_UNIT_OF_MEASUREMENT], + "unit_class": unit_class, } id_meta = (row_id, meta) results[statistic_id] = id_meta @@ -206,6 +218,7 @@ class StatisticsMetaManager: old_metadata["mean_type"] != new_metadata["mean_type"] or old_metadata["has_sum"] != new_metadata["has_sum"] or old_metadata["name"] != new_metadata["name"] + or old_metadata["unit_class"] != new_metadata["unit_class"] or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"] ): @@ -217,6 +230,7 @@ class StatisticsMetaManager: StatisticsMeta.mean_type: new_metadata["mean_type"], StatisticsMeta.has_sum: new_metadata["has_sum"], StatisticsMeta.name: new_metadata["name"], + StatisticsMeta.unit_class: new_metadata["unit_class"], StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"], }, synchronize_session=False, @@ -328,7 +342,11 @@ class StatisticsMetaManager: ) def update_unit_of_measurement( - self, session: Session, statistic_id: str, new_unit: str | None + self, + session: Session, + statistic_id: str, + new_unit_class: str | None, + new_unit: str | None, ) -> None: """Update the unit of measurement for a statistic_id. @@ -338,7 +356,12 @@ class StatisticsMetaManager: self._assert_in_recorder_thread() session.query(StatisticsMeta).filter( StatisticsMeta.statistic_id == statistic_id - ).update({StatisticsMeta.unit_of_measurement: new_unit}) + ).update( + { + StatisticsMeta.unit_of_measurement: new_unit, + StatisticsMeta.unit_class: new_unit_class, + } + ) self._clear_cache([statistic_id]) def update_statistic_id( diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index f5ad7f2a3d9..9ce021c59a5 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -77,6 +77,7 @@ class UpdateStatisticsMetadataTask(RecorderTask): on_done: Callable[[], None] | None statistic_id: str new_statistic_id: str | None | UndefinedType + new_unit_class: str | None | UndefinedType new_unit_of_measurement: str | None | UndefinedType def run(self, instance: Recorder) -> None: @@ -85,6 +86,7 @@ class UpdateStatisticsMetadataTask(RecorderTask): instance, self.statistic_id, self.new_statistic_id, + self.new_unit_class, self.new_unit_of_measurement, ) if self.on_done: diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 4f798fb86d0..2c682e2ae48 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from datetime import datetime as dt +import logging from typing import Any, Literal, cast import voluptuous as vol @@ -14,6 +15,7 @@ from homeassistant.core import HomeAssistant, callback, valid_entity_id from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.json import json_bytes +from homeassistant.helpers.typing import UNDEFINED from homeassistant.util import dt as dt_util from homeassistant.util.unit_conversion import ( ApparentPowerConverter, @@ -43,11 +45,12 @@ from homeassistant.util.unit_conversion import ( from .models import StatisticMeanType, StatisticPeriod from .statistics import ( - STATISTIC_UNIT_TO_UNIT_CONVERTER, + UNIT_CLASS_TO_UNIT_CONVERTER, async_add_external_statistics, async_change_statistics_unit, async_import_statistics, async_list_statistic_ids, + async_update_statistics_metadata, list_statistic_ids, statistic_during_period, statistics_during_period, @@ -56,6 +59,8 @@ from .statistics import ( ) from .util import PERIOD_SCHEMA, get_instance, resolve_period +_LOGGER = logging.getLogger(__name__) + CLEAR_STATISTICS_TIME_OUT = 10 UPDATE_STATISTICS_METADATA_TIME_OUT = 10 @@ -392,6 +397,7 @@ async def ws_get_statistics_metadata( { vol.Required("type"): "recorder/update_statistics_metadata", vol.Required("statistic_id"): str, + vol.Optional("unit_class"): vol.Any(str, None), vol.Required("unit_of_measurement"): vol.Any(str, None), } ) @@ -401,6 +407,8 @@ async def ws_update_statistics_metadata( ) -> None: """Update statistics metadata for a statistic_id. + The unit_class specifies which unit conversion class to use, if applicable. + Only the normalized unit of measurement can be updated. """ done_event = asyncio.Event() @@ -408,10 +416,20 @@ async def ws_update_statistics_metadata( def update_statistics_metadata_done() -> None: hass.loop.call_soon_threadsafe(done_event.set) - get_instance(hass).async_update_statistics_metadata( + if "unit_class" not in msg: + _LOGGER.warning( + "WS command recorder/update_statistics_metadata called without " + "specifying unit_class in metadata, this is deprecated and will " + "stop working in HA Core 2026.11" + ) + + async_update_statistics_metadata( + hass, msg["statistic_id"], + new_unit_class=msg.get("unit_class", UNDEFINED), new_unit_of_measurement=msg["unit_of_measurement"], on_done=update_statistics_metadata_done, + _called_from_ws_api=True, ) try: async with asyncio.timeout(UPDATE_STATISTICS_METADATA_TIME_OUT): @@ -434,15 +452,15 @@ async def ws_update_statistics_metadata( vol.Required("old_unit_of_measurement"): vol.Any(str, None), } ) -@callback -def ws_change_statistics_unit( +@websocket_api.async_response +async def ws_change_statistics_unit( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Change the unit_of_measurement for a statistic_id. All existing statistics will be converted to the new unit. """ - async_change_statistics_unit( + await async_change_statistics_unit( hass, msg["statistic_id"], new_unit_of_measurement=msg["new_unit_of_measurement"], @@ -487,17 +505,23 @@ async def ws_adjust_sum_statistics( return metadata = metadatas[0] - def valid_units(statistics_unit: str | None, adjustment_unit: str | None) -> bool: + def valid_units( + unit_class: str | None, statistics_unit: str | None, adjustment_unit: str | None + ) -> bool: if statistics_unit == adjustment_unit: return True - converter = STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistics_unit) - if converter is not None and adjustment_unit in converter.VALID_UNITS: + if ( + (converter := UNIT_CLASS_TO_UNIT_CONVERTER.get(unit_class)) is not None + and statistics_unit in converter.VALID_UNITS + and adjustment_unit in converter.VALID_UNITS + ): return True return False + unit_class = metadata["unit_class"] stat_unit = metadata["statistics_unit_of_measurement"] adjustment_unit = msg["adjustment_unit_of_measurement"] - if not valid_units(stat_unit, adjustment_unit): + if not valid_units(unit_class, stat_unit, adjustment_unit): connection.send_error( msg["id"], "invalid_units", @@ -521,6 +545,7 @@ async def ws_adjust_sum_statistics( vol.Required("name"): vol.Any(str, None), vol.Required("source"): str, vol.Required("statistic_id"): str, + vol.Optional("unit_class"): vol.Any(str, None), vol.Required("unit_of_measurement"): vol.Any(str, None), }, vol.Required("stats"): [ @@ -540,16 +565,25 @@ async def ws_adjust_sum_statistics( def ws_import_statistics( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: - """Import statistics.""" + """Import statistics. + + The unit_class specifies which unit conversion class to use, if applicable. + """ metadata = msg["metadata"] # The WS command will be changed in a follow up PR metadata["mean_type"] = ( StatisticMeanType.ARITHMETIC if metadata["has_mean"] else StatisticMeanType.NONE ) + if "unit_class" not in metadata: + _LOGGER.warning( + "WS command recorder/import_statistics called without specifying " + "unit_class in metadata, this is deprecated and will stop working " + "in HA Core 2026.11" + ) stats = msg["stats"] if valid_entity_id(metadata["statistic_id"]): - async_import_statistics(hass, metadata, stats) + async_import_statistics(hass, metadata, stats, _called_from_ws_api=True) else: - async_add_external_statistics(hass, metadata, stats) + async_add_external_statistics(hass, metadata, stats, _called_from_ws_api=True) connection.send_result(msg["id"]) diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py index c20a3e2e1ae..07f4565a6c3 100644 --- a/homeassistant/components/sensor/recorder.py +++ b/homeassistant/components/sensor/recorder.py @@ -27,6 +27,7 @@ from homeassistant.components.recorder.models import ( StatisticResult, ) from homeassistant.const import ( + ATTR_DEVICE_CLASS, ATTR_UNIT_OF_MEASUREMENT, REVOLUTIONS_PER_MINUTE, UnitOfIrradiance, @@ -43,12 +44,14 @@ from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.enum import try_parse_enum from homeassistant.util.hass_dict import HassKey +from homeassistant.util.unit_conversion import BaseUnitConverter from .const import ( AMBIGUOUS_UNITS, ATTR_LAST_RESET, ATTR_STATE_CLASS, DOMAIN, + UNIT_CONVERTERS, SensorStateClass, UnitOfVolumeFlowRate, ) @@ -238,12 +241,41 @@ def _is_numeric(state: State) -> bool: return False +def _get_unit_class( + device_class: str | None, + unit: str | None, +) -> str | None: + """Return the unit class for the given device class and unit. + + The unit class is determined from the device class and unit if possible, + otherwise from the unit. + """ + if ( + device_class + and (conv := UNIT_CONVERTERS.get(device_class)) + and unit in conv.VALID_UNITS + ): + return conv.UNIT_CLASS + if conv := statistics.STATISTIC_UNIT_TO_UNIT_CONVERTER.get(unit): + return conv.UNIT_CLASS + return None + + +def _get_unit_converter( + unit_class: str | None, +) -> type[BaseUnitConverter] | None: + """Return the unit converter for the given unit class.""" + if not unit_class: + return None + return statistics.UNIT_CLASS_TO_UNIT_CONVERTER[unit_class] + + def _normalize_states( hass: HomeAssistant, old_metadatas: dict[str, tuple[int, StatisticMetaData]], fstates: list[tuple[float, State]], entity_id: str, -) -> tuple[str | None, list[tuple[float, State]]]: +) -> tuple[str | None, str | None, list[tuple[float, State]]]: """Normalize units.""" state_unit: str | None = None statistics_unit: str | None @@ -253,11 +285,16 @@ def _normalize_states( # We've not seen this sensor before, the first valid state determines the unit # used for statistics statistics_unit = state_unit + unit_class = _get_unit_class( + fstates[0][1].attributes.get(ATTR_DEVICE_CLASS), + state_unit, + ) else: # We have seen this sensor before, use the unit from metadata statistics_unit = old_metadata["unit_of_measurement"] + unit_class = old_metadata["unit_class"] - if statistics_unit not in statistics.STATISTIC_UNIT_TO_UNIT_CONVERTER: + if not (converter := _get_unit_converter(unit_class)): # The unit used by this sensor doesn't support unit conversion all_units = _get_units(fstates) @@ -283,11 +320,15 @@ def _normalize_states( extra, LINK_DEV_STATISTICS, ) - return None, [] + return None, None, [] - return state_unit, fstates + if state_unit != statistics_unit: + unit_class = _get_unit_class( + fstates[0][1].attributes.get(ATTR_DEVICE_CLASS), + state_unit, + ) + return unit_class, state_unit, fstates - converter = statistics.STATISTIC_UNIT_TO_UNIT_CONVERTER[statistics_unit] valid_fstates: list[tuple[float, State]] = [] convert: Callable[[float], float] | None = None last_unit: str | None | UndefinedType = UNDEFINED @@ -330,7 +371,7 @@ def _normalize_states( valid_fstates.append((fstate, state)) - return statistics_unit, valid_fstates + return unit_class, statistics_unit, valid_fstates def _suggest_report_issue(hass: HomeAssistant, entity_id: str) -> str: @@ -516,13 +557,15 @@ def compile_statistics( # noqa: C901 old_metadatas = statistics.get_metadata_with_session( get_instance(hass), session, statistic_ids=set(entities_with_float_states) ) - to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = [] + to_process: list[ + tuple[str, str | None, str | None, str, list[tuple[float, State]]] + ] = [] to_query: set[str] = set() for _state in sensor_states: entity_id = _state.entity_id if not (maybe_float_states := entities_with_float_states.get(entity_id)): continue - statistics_unit, valid_float_states = _normalize_states( + unit_class, statistics_unit, valid_float_states = _normalize_states( hass, old_metadatas, maybe_float_states, @@ -531,7 +574,9 @@ def compile_statistics( # noqa: C901 if not valid_float_states: continue state_class: str = _state.attributes[ATTR_STATE_CLASS] - to_process.append((entity_id, statistics_unit, state_class, valid_float_states)) + to_process.append( + (entity_id, unit_class, statistics_unit, state_class, valid_float_states) + ) if "sum" in wanted_statistics[entity_id].types: to_query.add(entity_id) @@ -540,6 +585,7 @@ def compile_statistics( # noqa: C901 ) for ( # pylint: disable=too-many-nested-blocks entity_id, + unit_class, statistics_unit, state_class, valid_float_states, @@ -604,6 +650,7 @@ def compile_statistics( # noqa: C901 "name": None, "source": RECORDER_DOMAIN, "statistic_id": entity_id, + "unit_class": unit_class, "unit_of_measurement": statistics_unit, } @@ -769,13 +816,17 @@ def list_statistic_ids( if "mean" in provided_statistics.types: mean_type = provided_statistics.mean_type + unit = attributes.get(ATTR_UNIT_OF_MEASUREMENT) + unit_class = _get_unit_class(attributes.get(ATTR_DEVICE_CLASS), unit) + result[entity_id] = { "mean_type": mean_type, "has_sum": has_sum, "name": None, "source": RECORDER_DOMAIN, "statistic_id": entity_id, - "unit_of_measurement": attributes.get(ATTR_UNIT_OF_MEASUREMENT), + "unit_class": unit_class, + "unit_of_measurement": unit, } return result diff --git a/homeassistant/components/suez_water/coordinator.py b/homeassistant/components/suez_water/coordinator.py index 55f1e4955c2..2ca6fc540df 100644 --- a/homeassistant/components/suez_water/coordinator.py +++ b/homeassistant/components/suez_water/coordinator.py @@ -25,6 +25,7 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryError from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed import homeassistant.util.dt as dt_util +from homeassistant.util.unit_conversion import VolumeConverter from .const import CONF_COUNTER_ID, DATA_REFRESH_INTERVAL, DOMAIN @@ -211,7 +212,10 @@ class SuezWaterCoordinator(DataUpdateCoordinator[SuezWaterData]): ) -> None: """Persist given statistics in recorder.""" consumption_metadata = self._get_statistics_metadata( - id=self._water_statistic_id, name="Consumption", unit=UnitOfVolume.LITERS + id=self._water_statistic_id, + name="Consumption", + unit=UnitOfVolume.LITERS, + unit_class=VolumeConverter.UNIT_CLASS, ) _LOGGER.debug( @@ -230,14 +234,17 @@ class SuezWaterCoordinator(DataUpdateCoordinator[SuezWaterData]): self._cost_statistic_id, ) cost_metadata = self._get_statistics_metadata( - id=self._cost_statistic_id, name="Cost", unit=CURRENCY_EURO + id=self._cost_statistic_id, + name="Cost", + unit=CURRENCY_EURO, + unit_class=None, ) async_add_external_statistics(self.hass, cost_metadata, cost_statistics) _LOGGER.debug("Updated statistics for %s", self._water_statistic_id) def _get_statistics_metadata( - self, id: str, name: str, unit: str + self, id: str, name: str, unit: str, unit_class: str | None ) -> StatisticMetaData: """Build statistics metadata for requested configuration.""" return StatisticMetaData( @@ -246,6 +253,7 @@ class SuezWaterCoordinator(DataUpdateCoordinator[SuezWaterData]): name=f"Suez water {name} {self._counter_id}", source=DOMAIN, statistic_id=id, + unit_class=unit_class, unit_of_measurement=unit, ) diff --git a/homeassistant/components/tibber/coordinator.py b/homeassistant/components/tibber/coordinator.py index 8335cc2d773..2e420957c43 100644 --- a/homeassistant/components/tibber/coordinator.py +++ b/homeassistant/components/tibber/coordinator.py @@ -24,6 +24,7 @@ from homeassistant.const import UnitOfEnergy from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.util import dt as dt_util +from homeassistant.util.unit_conversion import EnergyConverter from .const import DOMAIN @@ -70,15 +71,29 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]): async def _insert_statistics(self) -> None: """Insert Tibber statistics.""" for home in self._tibber_connection.get_homes(): - sensors: list[tuple[str, bool, str]] = [] + sensors: list[tuple[str, bool, str | None, str]] = [] if home.hourly_consumption_data: - sensors.append(("consumption", False, UnitOfEnergy.KILO_WATT_HOUR)) - sensors.append(("totalCost", False, home.currency)) + sensors.append( + ( + "consumption", + False, + EnergyConverter.UNIT_CLASS, + UnitOfEnergy.KILO_WATT_HOUR, + ) + ) + sensors.append(("totalCost", False, None, home.currency)) if home.hourly_production_data: - sensors.append(("production", True, UnitOfEnergy.KILO_WATT_HOUR)) - sensors.append(("profit", True, home.currency)) + sensors.append( + ( + "production", + True, + EnergyConverter.UNIT_CLASS, + UnitOfEnergy.KILO_WATT_HOUR, + ) + ) + sensors.append(("profit", True, None, home.currency)) - for sensor_type, is_production, unit in sensors: + for sensor_type, is_production, unit_class, unit in sensors: statistic_id = ( f"{DOMAIN}:energy_" f"{sensor_type.lower()}_" @@ -168,6 +183,7 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]): name=f"{home.name} {sensor_type}", source=DOMAIN, statistic_id=statistic_id, + unit_class=unit_class, unit_of_measurement=unit, ) async_add_external_statistics(self.hass, metadata, statistics) diff --git a/tests/components/energy/test_websocket_api.py b/tests/components/energy/test_websocket_api.py index 54f2a971fd4..af8233d46fd 100644 --- a/tests/components/energy/test_websocket_api.py +++ b/tests/components/energy/test_websocket_api.py @@ -370,6 +370,7 @@ async def test_fossil_energy_consumption_no_co2( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_energy_statistics_2 = ( @@ -404,6 +405,7 @@ async def test_fossil_energy_consumption_no_co2( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_2", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -535,6 +537,7 @@ async def test_fossil_energy_consumption_hole( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_energy_statistics_2 = ( @@ -569,6 +572,7 @@ async def test_fossil_energy_consumption_hole( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_2", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -698,6 +702,7 @@ async def test_fossil_energy_consumption_no_data( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_energy_statistics_2 = ( @@ -732,6 +737,7 @@ async def test_fossil_energy_consumption_no_data( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_2", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -850,6 +856,7 @@ async def test_fossil_energy_consumption( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_energy_statistics_2 = ( @@ -884,6 +891,7 @@ async def test_fossil_energy_consumption( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_2", + "unit_class": "energy", "unit_of_measurement": "Wh", } external_co2_statistics = ( @@ -914,6 +922,7 @@ async def test_fossil_energy_consumption( "name": "Fossil percentage", "source": "test", "statistic_id": "test:fossil_percentage", + "unit_class": None, "unit_of_measurement": "%", } @@ -1101,6 +1110,7 @@ async def test_fossil_energy_consumption_check_missing_hour( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1135,6 +1145,7 @@ async def test_fossil_energy_consumption_check_missing_hour( "name": "Fossil percentage", "source": "test", "statistic_id": "test:fossil_percentage", + "unit_class": None, "unit_of_measurement": "%", } @@ -1196,6 +1207,7 @@ async def test_fossil_energy_consumption_missing_sum( "name": "Mean imported energy", "source": "test", "statistic_id": "test:mean_energy_import_tariff", + "unit_class": "energy", "unit_of_measurement": "kWh", } diff --git a/tests/components/kitchen_sink/test_init.py b/tests/components/kitchen_sink/test_init.py index 526801aecfa..088a9e9c349 100644 --- a/tests/components/kitchen_sink/test_init.py +++ b/tests/components/kitchen_sink/test_init.py @@ -81,6 +81,7 @@ async def test_demo_statistics_growth(hass: HomeAssistant) -> None: "source": DOMAIN, "name": "Energy consumption 1", "statistic_id": statistic_id, + "unit_class": "volume", "unit_of_measurement": "m³", "has_mean": False, "has_sum": True, diff --git a/tests/components/opower/test_coordinator.py b/tests/components/opower/test_coordinator.py index 5f55fd481ba..29a27f66a0c 100644 --- a/tests/components/opower/test_coordinator.py +++ b/tests/components/opower/test_coordinator.py @@ -20,6 +20,7 @@ from homeassistant.const import UnitOfEnergy from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir from homeassistant.util import dt as dt_util +from homeassistant.util.unit_conversion import EnergyConverter from tests.common import MockConfigEntry from tests.components.recorder.common import async_wait_recording_done @@ -188,6 +189,7 @@ async def test_coordinator_migration( name="Opower pge elec 111111 consumption", source=DOMAIN, statistic_id=statistic_id, + unit_class=EnergyConverter.UNIT_CLASS, unit_of_measurement=UnitOfEnergy.KILO_WATT_HOUR, ) statistics_to_add = [ diff --git a/tests/components/recorder/auto_repairs/statistics/test_duplicates.py b/tests/components/recorder/auto_repairs/statistics/test_duplicates.py index 91f51b4e0c9..65d74f3651c 100644 --- a/tests/components/recorder/auto_repairs/statistics/test_duplicates.py +++ b/tests/components/recorder/auto_repairs/statistics/test_duplicates.py @@ -64,6 +64,7 @@ async def test_duplicate_statistics_handle_integrity_error( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import_tariff_1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_energy_statistics_1 = [ diff --git a/tests/components/recorder/db_schema_50.py b/tests/components/recorder/db_schema_50.py new file mode 100644 index 00000000000..02d02521525 --- /dev/null +++ b/tests/components/recorder/db_schema_50.py @@ -0,0 +1,892 @@ +"""Models for SQLAlchemy. + +This file contains the model definitions for schema version 50. +It is used to test the schema migration logic. +""" + +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime, timedelta +import logging +import time +from typing import Any, Final, Protocol, Self + +import ciso8601 +from fnv_hash_fast import fnv1a_32 +from sqlalchemy import ( + CHAR, + JSON, + BigInteger, + Boolean, + ColumnElement, + DateTime, + Float, + ForeignKey, + Identity, + Index, + Integer, + LargeBinary, + SmallInteger, + String, + Text, + case, + type_coerce, +) +from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column, relationship +from sqlalchemy.types import TypeDecorator + +from homeassistant.components.recorder.const import ( + ALL_DOMAIN_EXCLUDE_ATTRS, + SupportedDialect, +) +from homeassistant.components.recorder.models import ( + StatisticData, + StatisticDataTimestamp, + StatisticMeanType, + StatisticMetaData, + datetime_to_timestamp_or_none, + process_timestamp, + ulid_to_bytes_or_none, + uuid_hex_to_bytes_or_none, +) +from homeassistant.components.sensor import ATTR_STATE_CLASS +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + ATTR_FRIENDLY_NAME, + ATTR_UNIT_OF_MEASUREMENT, + MATCH_ALL, + MAX_LENGTH_EVENT_EVENT_TYPE, + MAX_LENGTH_STATE_ENTITY_ID, + MAX_LENGTH_STATE_STATE, +) +from homeassistant.core import Event, EventStateChangedData +from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null +from homeassistant.util import dt as dt_util + + +# SQLAlchemy Schema +class Base(DeclarativeBase): + """Base class for tables.""" + + +class LegacyBase(DeclarativeBase): + """Base class for tables, used for schema migration.""" + + +SCHEMA_VERSION = 50 + +_LOGGER = logging.getLogger(__name__) + +TABLE_EVENTS = "events" +TABLE_EVENT_DATA = "event_data" +TABLE_EVENT_TYPES = "event_types" +TABLE_STATES = "states" +TABLE_STATE_ATTRIBUTES = "state_attributes" +TABLE_STATES_META = "states_meta" +TABLE_RECORDER_RUNS = "recorder_runs" +TABLE_SCHEMA_CHANGES = "schema_changes" +TABLE_STATISTICS = "statistics" +TABLE_STATISTICS_META = "statistics_meta" +TABLE_STATISTICS_RUNS = "statistics_runs" +TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" +TABLE_MIGRATION_CHANGES = "migration_changes" + +STATISTICS_TABLES = ("statistics", "statistics_short_term") + +MAX_STATE_ATTRS_BYTES = 16384 +MAX_EVENT_DATA_BYTES = 32768 + +PSQL_DIALECT = SupportedDialect.POSTGRESQL + +ALL_TABLES = [ + TABLE_STATES, + TABLE_STATE_ATTRIBUTES, + TABLE_EVENTS, + TABLE_EVENT_DATA, + TABLE_EVENT_TYPES, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, + TABLE_MIGRATION_CHANGES, + TABLE_STATES_META, + TABLE_STATISTICS, + TABLE_STATISTICS_META, + TABLE_STATISTICS_RUNS, + TABLE_STATISTICS_SHORT_TERM, +] + +TABLES_TO_CHECK = [ + TABLE_STATES, + TABLE_EVENTS, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, +] + +LAST_UPDATED_INDEX_TS = "ix_states_last_updated_ts" +METADATA_ID_LAST_UPDATED_INDEX_TS = "ix_states_metadata_id_last_updated_ts" +EVENTS_CONTEXT_ID_BIN_INDEX = "ix_events_context_id_bin" +STATES_CONTEXT_ID_BIN_INDEX = "ix_states_context_id_bin" +LEGACY_STATES_EVENT_ID_INDEX = "ix_states_event_id" +LEGACY_STATES_ENTITY_ID_LAST_UPDATED_TS_INDEX = "ix_states_entity_id_last_updated_ts" +LEGACY_MAX_LENGTH_EVENT_CONTEXT_ID: Final = 36 +CONTEXT_ID_BIN_MAX_LENGTH = 16 + +MYSQL_COLLATE = "utf8mb4_unicode_ci" +MYSQL_DEFAULT_CHARSET = "utf8mb4" +MYSQL_ENGINE = "InnoDB" + +_DEFAULT_TABLE_ARGS = { + "mysql_default_charset": MYSQL_DEFAULT_CHARSET, + "mysql_collate": MYSQL_COLLATE, + "mysql_engine": MYSQL_ENGINE, + "mariadb_default_charset": MYSQL_DEFAULT_CHARSET, + "mariadb_collate": MYSQL_COLLATE, + "mariadb_engine": MYSQL_ENGINE, +} + +_MATCH_ALL_KEEP = { + ATTR_DEVICE_CLASS, + ATTR_STATE_CLASS, + ATTR_UNIT_OF_MEASUREMENT, + ATTR_FRIENDLY_NAME, +} + + +class UnusedDateTime(DateTime): + """An unused column type that behaves like a datetime.""" + + +class Unused(CHAR): + """An unused column type that behaves like a string.""" + + +@compiles(UnusedDateTime, "mysql", "mariadb", "sqlite") +@compiles(Unused, "mysql", "mariadb", "sqlite") +def compile_char_zero(type_: TypeDecorator, compiler: Any, **kw: Any) -> str: + """Compile UnusedDateTime and Unused as CHAR(0) on mysql, mariadb, and sqlite.""" + return "CHAR(0)" # Uses 1 byte on MySQL (no change on sqlite) + + +@compiles(Unused, "postgresql") +def compile_char_one(type_: TypeDecorator, compiler: Any, **kw: Any) -> str: + """Compile Unused as CHAR(1) on postgresql.""" + return "CHAR(1)" # Uses 1 byte + + +class FAST_PYSQLITE_DATETIME(sqlite.DATETIME): + """Use ciso8601 to parse datetimes instead of sqlalchemy built-in regex.""" + + def result_processor(self, dialect: Dialect, coltype: Any) -> Callable | None: + """Offload the datetime parsing to ciso8601.""" + return lambda value: None if value is None else ciso8601.parse_datetime(value) + + +class NativeLargeBinary(LargeBinary): + """A faster version of LargeBinary for engines that support python bytes natively.""" + + def result_processor(self, dialect: Dialect, coltype: Any) -> Callable | None: + """No conversion needed for engines that support native bytes.""" + return None + + +# Although all integers are same in SQLite, it does not allow an identity column to be BIGINT +# https://sqlite.org/forum/info/2dfa968a702e1506e885cb06d92157d492108b22bf39459506ab9f7125bca7fd +ID_TYPE = BigInteger().with_variant(sqlite.INTEGER, "sqlite") +# For MariaDB and MySQL we can use an unsigned integer type since it will fit 2**32 +# for sqlite and postgresql we use a bigint +UINT_32_TYPE = BigInteger().with_variant( + mysql.INTEGER(unsigned=True), # type: ignore[no-untyped-call] + "mysql", + "mariadb", +) +JSON_VARIANT_CAST = Text().with_variant( + postgresql.JSON(none_as_null=True), + "postgresql", +) +JSONB_VARIANT_CAST = Text().with_variant( + postgresql.JSONB(none_as_null=True), + "postgresql", +) +DATETIME_TYPE = ( + DateTime(timezone=True) + .with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql", "mariadb") # type: ignore[no-untyped-call] + .with_variant(FAST_PYSQLITE_DATETIME(), "sqlite") # type: ignore[no-untyped-call] +) +DOUBLE_TYPE = ( + Float() + .with_variant(mysql.DOUBLE(asdecimal=False), "mysql", "mariadb") # type: ignore[no-untyped-call] + .with_variant(oracle.DOUBLE_PRECISION(), "oracle") + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") +) +UNUSED_LEGACY_COLUMN = Unused(0) +UNUSED_LEGACY_DATETIME_COLUMN = UnusedDateTime(timezone=True) +UNUSED_LEGACY_INTEGER_COLUMN = SmallInteger() +DOUBLE_PRECISION_TYPE_SQL = "DOUBLE PRECISION" +BIG_INTEGER_SQL = "BIGINT" +CONTEXT_BINARY_TYPE = LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH).with_variant( + NativeLargeBinary(CONTEXT_ID_BIN_MAX_LENGTH), "mysql", "mariadb", "sqlite" +) + +TIMESTAMP_TYPE = DOUBLE_TYPE + + +class _LiteralProcessorType(Protocol): + def __call__(self, value: Any) -> str: ... + + +class JSONLiteral(JSON): + """Teach SA how to literalize json.""" + + def literal_processor(self, dialect: Dialect) -> _LiteralProcessorType: + """Processor to convert a value to JSON.""" + + def process(value: Any) -> str: + """Dump json.""" + return JSON_DUMP(value) + + return process + + +class Events(Base): + """Event history data.""" + + __table_args__ = ( + # Used for fetching events at a specific time + # see logbook + Index( + "ix_events_event_type_id_time_fired_ts", "event_type_id", "time_fired_ts" + ), + Index( + EVENTS_CONTEXT_ID_BIN_INDEX, + "context_id_bin", + mysql_length=CONTEXT_ID_BIN_MAX_LENGTH, + mariadb_length=CONTEXT_ID_BIN_MAX_LENGTH, + ), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_EVENTS + event_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + event_type: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + event_data: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + origin: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + origin_idx: Mapped[int | None] = mapped_column(SmallInteger) + time_fired: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + time_fired_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True) + context_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + context_user_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + context_parent_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + data_id: Mapped[int | None] = mapped_column( + ID_TYPE, ForeignKey("event_data.data_id"), index=True + ) + context_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE) + context_user_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE) + context_parent_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE) + event_type_id: Mapped[int | None] = mapped_column( + ID_TYPE, ForeignKey("event_types.event_type_id") + ) + event_data_rel: Mapped[EventData | None] = relationship("EventData") + event_type_rel: Mapped[EventTypes | None] = relationship("EventTypes") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + @property + def _time_fired_isotime(self) -> str | None: + """Return time_fired as an isotime string.""" + date_time: datetime | None + if self.time_fired_ts is not None: + date_time = dt_util.utc_from_timestamp(self.time_fired_ts) + else: + date_time = process_timestamp(self.time_fired) + if date_time is None: + return None + return date_time.isoformat(sep=" ", timespec="seconds") + + @staticmethod + def from_event(event: Event) -> Events: + """Create an event database object from a native event.""" + context = event.context + return Events( + event_type=None, + event_data=None, + origin_idx=event.origin.idx, + time_fired=None, + time_fired_ts=event.time_fired_timestamp, + context_id=None, + context_id_bin=ulid_to_bytes_or_none(context.id), + context_user_id=None, + context_user_id_bin=uuid_hex_to_bytes_or_none(context.user_id), + context_parent_id=None, + context_parent_id_bin=ulid_to_bytes_or_none(context.parent_id), + ) + + +class LegacyEvents(LegacyBase): + """Event history data with event_id, used for schema migration.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_EVENTS + event_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + context_id: Mapped[str | None] = mapped_column( + String(LEGACY_MAX_LENGTH_EVENT_CONTEXT_ID), index=True + ) + + +class EventData(Base): + """Event data history.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_EVENT_DATA + data_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + hash: Mapped[int | None] = mapped_column(UINT_32_TYPE, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_data: Mapped[str | None] = mapped_column( + Text().with_variant(mysql.LONGTEXT, "mysql", "mariadb") + ) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + @staticmethod + def shared_data_bytes_from_event( + event: Event, dialect: SupportedDialect | None + ) -> bytes: + """Create shared_data from an event.""" + encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes + bytes_result = encoder(event.data) + if len(bytes_result) > MAX_EVENT_DATA_BYTES: + _LOGGER.warning( + "Event data for %s exceed maximum size of %s bytes. " + "This can cause database performance issues; Event data " + "will not be stored", + event.event_type, + MAX_EVENT_DATA_BYTES, + ) + return b"{}" + return bytes_result + + @staticmethod + def hash_shared_data_bytes(shared_data_bytes: bytes) -> int: + """Return the hash of json encoded shared data.""" + return fnv1a_32(shared_data_bytes) + + +class EventTypes(Base): + """Event type history.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_EVENT_TYPES + event_type_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + event_type: Mapped[str | None] = mapped_column( + String(MAX_LENGTH_EVENT_EVENT_TYPE), index=True, unique=True + ) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + +class States(Base): + """State change history.""" + + __table_args__ = ( + # Used for fetching the state of entities at a specific time + # (get_states in history.py) + Index(METADATA_ID_LAST_UPDATED_INDEX_TS, "metadata_id", "last_updated_ts"), + Index( + STATES_CONTEXT_ID_BIN_INDEX, + "context_id_bin", + mysql_length=CONTEXT_ID_BIN_MAX_LENGTH, + mariadb_length=CONTEXT_ID_BIN_MAX_LENGTH, + ), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_STATES + state_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + entity_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + state: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_STATE)) + attributes: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + event_id: Mapped[int | None] = mapped_column(UNUSED_LEGACY_INTEGER_COLUMN) + last_changed: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + last_changed_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE) + last_reported_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE) + last_updated: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + last_updated_ts: Mapped[float | None] = mapped_column( + TIMESTAMP_TYPE, default=time.time, index=True + ) + old_state_id: Mapped[int | None] = mapped_column( + ID_TYPE, ForeignKey("states.state_id"), index=True + ) + attributes_id: Mapped[int | None] = mapped_column( + ID_TYPE, ForeignKey("state_attributes.attributes_id"), index=True + ) + context_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + context_user_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + context_parent_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + origin_idx: Mapped[int | None] = mapped_column( + SmallInteger + ) # 0 is local, 1 is remote + old_state: Mapped[States | None] = relationship("States", remote_side=[state_id]) + state_attributes: Mapped[StateAttributes | None] = relationship("StateAttributes") + context_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE) + context_user_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE) + context_parent_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE) + metadata_id: Mapped[int | None] = mapped_column( + ID_TYPE, ForeignKey("states_meta.metadata_id") + ) + states_meta_rel: Mapped[StatesMeta | None] = relationship("StatesMeta") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @property + def _last_updated_isotime(self) -> str | None: + """Return last_updated as an isotime string.""" + date_time: datetime | None + if self.last_updated_ts is not None: + date_time = dt_util.utc_from_timestamp(self.last_updated_ts) + else: + date_time = process_timestamp(self.last_updated) + if date_time is None: + return None + return date_time.isoformat(sep=" ", timespec="seconds") + + @staticmethod + def from_event(event: Event[EventStateChangedData]) -> States: + """Create object from a state_changed event.""" + state = event.data["new_state"] + # None state means the state was removed from the state machine + if state is None: + state_value = "" + last_updated_ts = event.time_fired_timestamp + last_changed_ts = None + last_reported_ts = None + else: + state_value = state.state + last_updated_ts = state.last_updated_timestamp + if state.last_updated == state.last_changed: + last_changed_ts = None + else: + last_changed_ts = state.last_changed_timestamp + if state.last_updated == state.last_reported: + last_reported_ts = None + else: + last_reported_ts = state.last_reported_timestamp + context = event.context + return States( + state=state_value, + entity_id=None, + attributes=None, + context_id=None, + context_id_bin=ulid_to_bytes_or_none(context.id), + context_user_id=None, + context_user_id_bin=uuid_hex_to_bytes_or_none(context.user_id), + context_parent_id=None, + context_parent_id_bin=ulid_to_bytes_or_none(context.parent_id), + origin_idx=event.origin.idx, + last_updated=None, + last_changed=None, + last_updated_ts=last_updated_ts, + last_changed_ts=last_changed_ts, + last_reported_ts=last_reported_ts, + ) + + +class LegacyStates(LegacyBase): + """State change history with entity_id, used for schema migration.""" + + __table_args__ = ( + Index( + LEGACY_STATES_ENTITY_ID_LAST_UPDATED_TS_INDEX, + "entity_id", + "last_updated_ts", + ), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_STATES + state_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + entity_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN) + last_updated_ts: Mapped[float | None] = mapped_column( + TIMESTAMP_TYPE, default=time.time, index=True + ) + context_id: Mapped[str | None] = mapped_column( + String(LEGACY_MAX_LENGTH_EVENT_CONTEXT_ID), index=True + ) + + +class StateAttributes(Base): + """State attribute change history.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_STATE_ATTRIBUTES + attributes_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + hash: Mapped[int | None] = mapped_column(UINT_32_TYPE, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_attrs: Mapped[str | None] = mapped_column( + Text().with_variant(mysql.LONGTEXT, "mysql", "mariadb") + ) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def shared_attrs_bytes_from_event( + event: Event[EventStateChangedData], + dialect: SupportedDialect | None, + ) -> bytes: + """Create shared_attrs from a state_changed event.""" + # None state means the state was removed from the state machine + if (state := event.data["new_state"]) is None: + return b"{}" + if state_info := state.state_info: + unrecorded_attributes = state_info["unrecorded_attributes"] + exclude_attrs = { + *ALL_DOMAIN_EXCLUDE_ATTRS, + *unrecorded_attributes, + } + if MATCH_ALL in unrecorded_attributes: + # Don't exclude device class, state class, unit of measurement + # or friendly name when using the MATCH_ALL exclude constant + exclude_attrs.update(state.attributes) + exclude_attrs -= _MATCH_ALL_KEEP + else: + exclude_attrs = ALL_DOMAIN_EXCLUDE_ATTRS + encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes + bytes_result = encoder( + {k: v for k, v in state.attributes.items() if k not in exclude_attrs} + ) + if len(bytes_result) > MAX_STATE_ATTRS_BYTES: + _LOGGER.warning( + "State attributes for %s exceed maximum size of %s bytes. " + "This can cause database performance issues; Attributes " + "will not be stored", + state.entity_id, + MAX_STATE_ATTRS_BYTES, + ) + return b"{}" + return bytes_result + + @staticmethod + def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int: + """Return the hash of json encoded shared attributes.""" + return fnv1a_32(shared_attrs_bytes) + + +class StatesMeta(Base): + """Metadata for states.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_STATES_META + metadata_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + entity_id: Mapped[str | None] = mapped_column( + String(MAX_LENGTH_STATE_ENTITY_ID), index=True, unique=True + ) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + +class StatisticsBase: + """Statistics base class.""" + + id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + created: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + created_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, default=time.time) + metadata_id: Mapped[int | None] = mapped_column( + ID_TYPE, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + ) + start: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + start_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True) + mean: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + mean_weight: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + min: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + max: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + last_reset: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + last_reset_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE) + state: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + sum: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + + duration: timedelta + + @classmethod + def from_stats( + cls, metadata_id: int, stats: StatisticData, now_timestamp: float | None = None + ) -> Self: + """Create object from a statistics with datetime objects.""" + return cls( # type: ignore[call-arg] + metadata_id=metadata_id, + created=None, + created_ts=now_timestamp or time.time(), + start=None, + start_ts=stats["start"].timestamp(), + mean=stats.get("mean"), + mean_weight=stats.get("mean_weight"), + min=stats.get("min"), + max=stats.get("max"), + last_reset=None, + last_reset_ts=datetime_to_timestamp_or_none(stats.get("last_reset")), + state=stats.get("state"), + sum=stats.get("sum"), + ) + + @classmethod + def from_stats_ts( + cls, + metadata_id: int, + stats: StatisticDataTimestamp, + now_timestamp: float | None = None, + ) -> Self: + """Create object from a statistics with timestamps.""" + return cls( # type: ignore[call-arg] + metadata_id=metadata_id, + created=None, + created_ts=now_timestamp or time.time(), + start=None, + start_ts=stats["start_ts"], + mean=stats.get("mean"), + mean_weight=stats.get("mean_weight"), + min=stats.get("min"), + max=stats.get("max"), + last_reset=None, + last_reset_ts=stats.get("last_reset_ts"), + state=stats.get("state"), + sum=stats.get("sum"), + ) + + +class Statistics(Base, StatisticsBase): + """Long term statistics.""" + + duration = timedelta(hours=1) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_statistic_id_start_ts", + "metadata_id", + "start_ts", + unique=True, + ), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_STATISTICS + + +class _StatisticsShortTerm(StatisticsBase): + """Short term statistics.""" + + duration = timedelta(minutes=5) + + __tablename__ = TABLE_STATISTICS_SHORT_TERM + + +class StatisticsShortTerm(Base, _StatisticsShortTerm): + """Short term statistics.""" + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_short_term_statistic_id_start_ts", + "metadata_id", + "start_ts", + unique=True, + ), + _DEFAULT_TABLE_ARGS, + ) + + +class LegacyStatisticsShortTerm(LegacyBase, _StatisticsShortTerm): + """Short term statistics with 32-bit index, used for schema migration.""" + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_short_term_statistic_id_start_ts", + "metadata_id", + "start_ts", + unique=True, + ), + _DEFAULT_TABLE_ARGS, + ) + + metadata_id: Mapped[int | None] = mapped_column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + use_existing_column=True, + ) + + +class _StatisticsMeta: + """Statistics meta data.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_STATISTICS_META + id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + statistic_id: Mapped[str | None] = mapped_column( + String(255), index=True, unique=True + ) + source: Mapped[str | None] = mapped_column(String(32)) + unit_of_measurement: Mapped[str | None] = mapped_column(String(255)) + has_mean: Mapped[bool | None] = mapped_column(Boolean) + has_sum: Mapped[bool | None] = mapped_column(Boolean) + name: Mapped[str | None] = mapped_column(String(255)) + mean_type: Mapped[StatisticMeanType] = mapped_column( + SmallInteger, nullable=False, default=StatisticMeanType.NONE.value + ) # See StatisticMeanType + + @staticmethod + def from_meta(meta: StatisticMetaData) -> StatisticsMeta: + """Create object from meta data.""" + return StatisticsMeta(**meta) + + +class StatisticsMeta(Base, _StatisticsMeta): + """Statistics meta data.""" + + +class LegacyStatisticsMeta(LegacyBase, _StatisticsMeta): + """Statistics meta data with 32-bit index, used for schema migration.""" + + id: Mapped[int] = mapped_column( + Integer, + Identity(), + primary_key=True, + use_existing_column=True, + ) + + +class RecorderRuns(Base): + """Representation of recorder run.""" + + __table_args__ = ( + Index("ix_recorder_runs_start_end", "start", "end"), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_RECORDER_RUNS + run_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + start: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow) + end: Mapped[datetime | None] = mapped_column(DATETIME_TYPE) + closed_incorrect: Mapped[bool] = mapped_column(Boolean, default=False) + created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + end = ( + f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None + ) + return ( + f"" + ) + + +class MigrationChanges(Base): + """Representation of migration changes.""" + + __tablename__ = TABLE_MIGRATION_CHANGES + __table_args__ = (_DEFAULT_TABLE_ARGS,) + + migration_id: Mapped[str] = mapped_column(String(255), primary_key=True) + version: Mapped[int] = mapped_column(SmallInteger) + + +class SchemaChanges(Base): + """Representation of schema version changes.""" + + __tablename__ = TABLE_SCHEMA_CHANGES + __table_args__ = (_DEFAULT_TABLE_ARGS,) + + change_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + schema_version: Mapped[int | None] = mapped_column(Integer) + changed: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + +class StatisticsRuns(Base): + """Representation of statistics run.""" + + __tablename__ = TABLE_STATISTICS_RUNS + __table_args__ = (_DEFAULT_TABLE_ARGS,) + + run_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True) + start: Mapped[datetime] = mapped_column(DATETIME_TYPE, index=True) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +EVENT_DATA_JSON = type_coerce( + EventData.shared_data.cast(JSONB_VARIANT_CAST), JSONLiteral(none_as_null=True) +) +OLD_FORMAT_EVENT_DATA_JSON = type_coerce( + Events.event_data.cast(JSONB_VARIANT_CAST), JSONLiteral(none_as_null=True) +) + +SHARED_ATTRS_JSON = type_coerce( + StateAttributes.shared_attrs.cast(JSON_VARIANT_CAST), JSON(none_as_null=True) +) +OLD_FORMAT_ATTRS_JSON = type_coerce( + States.attributes.cast(JSON_VARIANT_CAST), JSON(none_as_null=True) +) + +ENTITY_ID_IN_EVENT: ColumnElement = EVENT_DATA_JSON["entity_id"] +OLD_ENTITY_ID_IN_EVENT: ColumnElement = OLD_FORMAT_EVENT_DATA_JSON["entity_id"] +DEVICE_ID_IN_EVENT: ColumnElement = EVENT_DATA_JSON["device_id"] +OLD_STATE = aliased(States, name="old_state") + +SHARED_ATTR_OR_LEGACY_ATTRIBUTES = case( + (StateAttributes.shared_attrs.is_(None), States.attributes), + else_=StateAttributes.shared_attrs, +).label("attributes") +SHARED_DATA_OR_LEGACY_EVENT_DATA = case( + (EventData.shared_data.is_(None), Events.event_data), else_=EventData.shared_data +).label("event_data") diff --git a/tests/components/recorder/table_managers/test_statistics_meta.py b/tests/components/recorder/table_managers/test_statistics_meta.py index 1af60b71ed5..280fec37bb4 100644 --- a/tests/components/recorder/table_managers/test_statistics_meta.py +++ b/tests/components/recorder/table_managers/test_statistics_meta.py @@ -87,6 +87,7 @@ async def test_invalid_mean_types( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.energy", + "unit_class": None, "unit_of_measurement": "kWh", }, ), @@ -99,6 +100,7 @@ async def test_invalid_mean_types( "name": "Wind direction", "source": "recorder", "statistic_id": "sensor.wind_direction", + "unit_class": None, "unit_of_measurement": DEGREE, }, ), @@ -111,6 +113,7 @@ async def test_invalid_mean_types( "name": "Wind speed", "source": "recorder", "statistic_id": "sensor.wind_speed", + "unit_class": None, "unit_of_measurement": "km/h", }, ), diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 74d319bcd97..43fb86eff32 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -562,6 +562,7 @@ async def test_events_during_migration_queue_exhausted( (25, False), (43, False), (48, True), + (50, True), ], ) async def test_schema_migrate( diff --git a/tests/components/recorder/test_migration_from_schema_50.py b/tests/components/recorder/test_migration_from_schema_50.py new file mode 100644 index 00000000000..238c7433f25 --- /dev/null +++ b/tests/components/recorder/test_migration_from_schema_50.py @@ -0,0 +1,291 @@ +"""The tests for the recorder filter matching the EntityFilter component.""" + +import importlib +import sys +import threading +from unittest.mock import patch + +import pytest +from pytest_unordered import unordered +from sqlalchemy import create_engine, inspect +from sqlalchemy.orm import Session + +from homeassistant.components import recorder +from homeassistant.components.recorder import core, migration, statistics +from homeassistant.components.recorder.const import UNIT_CLASS_SCHEMA_VERSION +from homeassistant.components.recorder.db_schema import StatisticsMeta +from homeassistant.components.recorder.models import StatisticMeanType +from homeassistant.components.recorder.util import session_scope +from homeassistant.core import HomeAssistant + +from .common import ( + async_recorder_block_till_done, + async_wait_recording_done, + get_patched_live_version, +) +from .conftest import instrument_migration + +from tests.common import async_test_home_assistant +from tests.typing import RecorderInstanceContextManager + +CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine" +SCHEMA_MODULE_50 = "tests.components.recorder.db_schema_50" + + +@pytest.fixture +async def mock_recorder_before_hass( + async_test_recorder: RecorderInstanceContextManager, +) -> None: + """Set up recorder.""" + + +async def _async_wait_migration_done(hass: HomeAssistant) -> None: + """Wait for the migration to be done.""" + await recorder.get_instance(hass).async_block_till_done() + await async_recorder_block_till_done(hass) + + +def _create_engine_test(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. + """ + importlib.import_module(SCHEMA_MODULE_50) + old_db_schema = sys.modules[SCHEMA_MODULE_50] + engine = create_engine(*args, **kwargs) + old_db_schema.Base.metadata.create_all(engine) + with Session(engine) as session: + session.add( + recorder.db_schema.StatisticsRuns(start=statistics.get_start_time()) + ) + session.add( + recorder.db_schema.SchemaChanges( + schema_version=old_db_schema.SCHEMA_VERSION + ) + ) + session.commit() + return engine + + +@pytest.fixture +def db_schema_50(): + """Fixture to initialize the db with the old schema.""" + importlib.import_module(SCHEMA_MODULE_50) + old_db_schema = sys.modules[SCHEMA_MODULE_50] + + with ( + patch.object(recorder, "db_schema", old_db_schema), + patch.object(migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION), + patch.object( + migration, + "LIVE_MIGRATION_MIN_SCHEMA_VERSION", + get_patched_live_version(old_db_schema), + ), + patch.object(migration, "non_live_data_migration_needed", return_value=False), + patch.object(core, "StatesMeta", old_db_schema.StatesMeta), + patch.object(core, "EventTypes", old_db_schema.EventTypes), + patch.object(core, "EventData", old_db_schema.EventData), + patch.object(core, "States", old_db_schema.States), + patch.object(core, "Events", old_db_schema.Events), + patch.object(core, "StateAttributes", old_db_schema.StateAttributes), + patch(CREATE_ENGINE_TARGET, new=_create_engine_test), + ): + yield + + +@pytest.mark.parametrize("persistent_database", [True]) +@pytest.mark.usefixtures("hass_storage") # Prevent test hass from writing to storage +async def test_migrate_statistics_meta( + async_test_recorder: RecorderInstanceContextManager, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test migration of metadata adding unit_class.""" + importlib.import_module(SCHEMA_MODULE_50) + old_db_schema = sys.modules[SCHEMA_MODULE_50] + + def _insert_metadata(): + with session_scope(hass=hass) as session: + session.add_all( + ( + old_db_schema.StatisticsMeta( + statistic_id="sensor.test1", + source="recorder", + unit_of_measurement="kWh", + has_mean=None, + has_sum=True, + name="Test 1", + mean_type=StatisticMeanType.NONE, + ), + old_db_schema.StatisticsMeta( + statistic_id="sensor.test2", + source="recorder", + unit_of_measurement="cats", + has_mean=None, + has_sum=True, + name="Test 2", + mean_type=StatisticMeanType.NONE, + ), + old_db_schema.StatisticsMeta( + statistic_id="sensor.test3", + source="recorder", + unit_of_measurement="ppm", + has_mean=None, + has_sum=True, + name="Test 3", + mean_type=StatisticMeanType.NONE, + ), + ) + ) + + # Create database with old schema + with ( + patch.object(recorder, "db_schema", old_db_schema), + patch.object(migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION), + patch.object( + migration, + "LIVE_MIGRATION_MIN_SCHEMA_VERSION", + get_patched_live_version(old_db_schema), + ), + patch.object(migration.EventsContextIDMigration, "migrate_data"), + patch(CREATE_ENGINE_TARGET, new=_create_engine_test), + ): + async with ( + async_test_home_assistant() as hass, + async_test_recorder(hass) as instance, + ): + await instance.async_add_executor_job(_insert_metadata) + + await async_wait_recording_done(hass) + await _async_wait_migration_done(hass) + + await hass.async_stop() + await hass.async_block_till_done() + + def _object_as_dict(obj): + return {c.key: getattr(obj, c.key) for c in inspect(obj).mapper.column_attrs} + + def _fetch_metadata(): + with session_scope(hass=hass) as session: + metadatas = session.query(StatisticsMeta).all() + return { + metadata.statistic_id: _object_as_dict(metadata) + for metadata in metadatas + } + + # Run again with new schema, let migration run + async with async_test_home_assistant() as hass: + with ( + instrument_migration(hass) as instrumented_migration, + ): + # Stall migration when the last non-live schema migration is done + instrumented_migration.stall_on_schema_version = UNIT_CLASS_SCHEMA_VERSION + async with async_test_recorder( + hass, wait_recorder=False, wait_recorder_setup=False + ) as instance: + # Wait for migration to reach migration of unit class + await hass.async_add_executor_job( + instrumented_migration.apply_update_stalled.wait + ) + + # Check that it's possible to read metadata via the API, this will + # stop working when version 50 is migrated off line + pre_migration_metadata_api = await instance.async_add_executor_job( + statistics.list_statistic_ids, + hass, + None, + None, + ) + + instrumented_migration.migration_stall.set() + instance.recorder_and_worker_thread_ids.add(threading.get_ident()) + + await hass.async_block_till_done() + await async_wait_recording_done(hass) + await async_wait_recording_done(hass) + + post_migration_metadata_db = await instance.async_add_executor_job( + _fetch_metadata + ) + post_migration_metadata_api = await instance.async_add_executor_job( + statistics.list_statistic_ids, + hass, + None, + None, + ) + + await hass.async_stop() + await hass.async_block_till_done() + + assert pre_migration_metadata_api == unordered( + [ + { + "display_unit_of_measurement": "kWh", + "has_mean": False, + "has_sum": True, + "mean_type": StatisticMeanType.NONE, + "name": "Test 1", + "source": "recorder", + "statistic_id": "sensor.test1", + "statistics_unit_of_measurement": "kWh", + "unit_class": "energy", + }, + { + "display_unit_of_measurement": "cats", + "has_mean": False, + "has_sum": True, + "mean_type": StatisticMeanType.NONE, + "name": "Test 2", + "source": "recorder", + "statistic_id": "sensor.test2", + "statistics_unit_of_measurement": "cats", + "unit_class": None, + }, + { + "display_unit_of_measurement": "ppm", + "has_mean": False, + "has_sum": True, + "mean_type": StatisticMeanType.NONE, + "name": "Test 3", + "source": "recorder", + "statistic_id": "sensor.test3", + "statistics_unit_of_measurement": "ppm", + "unit_class": "unitless", + }, + ] + ) + assert post_migration_metadata_db == { + "sensor.test1": { + "has_mean": None, + "has_sum": True, + "id": 1, + "mean_type": 0, + "name": "Test 1", + "source": "recorder", + "statistic_id": "sensor.test1", + "unit_class": "energy", + "unit_of_measurement": "kWh", + }, + "sensor.test2": { + "has_mean": None, + "has_sum": True, + "id": 2, + "mean_type": 0, + "name": "Test 2", + "source": "recorder", + "statistic_id": "sensor.test2", + "unit_class": None, + "unit_of_measurement": "cats", + }, + "sensor.test3": { + "has_mean": None, + "has_sum": True, + "id": 3, + "mean_type": 0, + "name": "Test 3", + "source": "recorder", + "statistic_id": "sensor.test3", + "unit_class": "unitless", + "unit_of_measurement": "ppm", + }, + } + assert post_migration_metadata_api == unordered(pre_migration_metadata_api) diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index d29ee04a469..8468865d058 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -29,6 +29,7 @@ from homeassistant.components.recorder.statistics import ( async_add_external_statistics, async_import_statistics, async_list_statistic_ids, + async_update_statistics_metadata, get_last_short_term_statistics, get_last_statistics, get_latest_short_term_statistics_with_session, @@ -48,6 +49,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util +from homeassistant.util.unit_system import METRIC_SYSTEM from .common import ( assert_dict_of_states_equal_without_context_and_last_changed, @@ -63,6 +65,12 @@ from tests.common import MockPlatform, MockUser, mock_platform from tests.typing import RecorderInstanceContextManager, WebSocketGenerator from tests.util.test_unit_conversion import _ALL_CONVERTERS +POWER_SENSOR_KW_ATTRIBUTES = { + "device_class": "power", + "state_class": "measurement", + "unit_of_measurement": "kW", +} + @pytest.fixture def multiple_start_time_chunk_sizes( @@ -397,6 +405,7 @@ def mock_sensor_statistics(): "has_sum": False, "name": None, "statistic_id": entity_id, + "unit_class": None, "unit_of_measurement": "dogs", }, "stat": {"start": start}, @@ -839,7 +848,18 @@ async def test_statistics_duplicated( caplog.clear() +# Integration frame mocked because of deprecation warnings about missing +# unit_class, can be removed in HA Core 2025.11 +@pytest.mark.parametrize("integration_frame_path", ["custom_components/my_integration"]) +@pytest.mark.usefixtures("mock_integration_frame") @pytest.mark.parametrize("last_reset_str", ["2022-01-01T00:00:00+02:00", None]) +@pytest.mark.parametrize( + ("external_metadata_extra"), + [ + {}, + {"unit_class": "energy"}, + ], +) @pytest.mark.parametrize( ("source", "statistic_id", "import_fn"), [ @@ -852,6 +872,7 @@ async def test_import_statistics( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, caplog: pytest.LogCaptureFixture, + external_metadata_extra: dict[str, str], source, statistic_id, import_fn, @@ -889,7 +910,7 @@ async def test_import_statistics( "source": source, "statistic_id": statistic_id, "unit_of_measurement": "kWh", - } + } | external_metadata_extra import_fn(hass, external_metadata, (external_statistics1, external_statistics2)) await async_wait_recording_done(hass) @@ -939,6 +960,7 @@ async def test_import_statistics( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, + "unit_class": "energy", "unit_of_measurement": "kWh", }, ) @@ -1031,6 +1053,7 @@ async def test_import_statistics( "name": "Total imported energy renamed", "source": source, "statistic_id": statistic_id, + "unit_class": "energy", "unit_of_measurement": "kWh", }, ) @@ -1119,6 +1142,7 @@ async def test_external_statistics_errors( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1207,6 +1231,7 @@ async def test_import_statistics_errors( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1270,6 +1295,213 @@ async def test_import_statistics_errors( assert get_metadata(hass, statistic_ids={"sensor.total_energy_import"}) == {} +# Integration frame mocked because of deprecation warnings about missing +# unit_class, can be removed in HA Core 2025.11 +@pytest.mark.parametrize("integration_frame_path", ["custom_components/my_integration"]) +@pytest.mark.usefixtures("mock_integration_frame") +@pytest.mark.parametrize( + ( + "requested_new_unit", + "update_statistics_extra", + "new_unit", + "new_unit_class", + "new_display_unit", + ), + [ + ("dogs", {}, "dogs", None, "dogs"), + ("dogs", {"new_unit_class": None}, "dogs", None, "dogs"), + (None, {}, None, "unitless", None), + (None, {"new_unit_class": "unitless"}, None, "unitless", None), + ("W", {}, "W", "power", "kW"), + ("W", {"new_unit_class": "power"}, "W", "power", "kW"), + # Note: Display unit is guessed even if unit_class is None + ("W", {"new_unit_class": None}, "W", None, "kW"), + ], +) +@pytest.mark.usefixtures("recorder_mock") +async def test_update_statistics_metadata( + hass: HomeAssistant, + requested_new_unit, + update_statistics_extra, + new_unit, + new_unit_class, + new_display_unit, +) -> None: + """Test removing statistics.""" + now = get_start_time(dt_util.utcnow()) + + units = METRIC_SYSTEM + attributes = POWER_SENSOR_KW_ATTRIBUTES | {"device_class": None} + state = 10 + + hass.config.units = units + await async_setup_component(hass, "sensor", {}) + await async_recorder_block_till_done(hass) + hass.states.async_set( + "sensor.test", state, attributes=attributes, timestamp=now.timestamp() + ) + await async_wait_recording_done(hass) + + do_adhoc_statistics(hass, period="hourly", start=now) + await async_recorder_block_till_done(hass) + + statistic_ids = await async_list_statistic_ids(hass) + assert statistic_ids == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": "kW", + "has_mean": True, + "mean_type": StatisticMeanType.ARITHMETIC, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": "kW", + "unit_class": "power", + } + ] + + async_update_statistics_metadata( + hass, + "sensor.test", + new_unit_of_measurement=requested_new_unit, + **update_statistics_extra, + ) + await async_recorder_block_till_done(hass) + + statistic_ids = await async_list_statistic_ids(hass) + assert statistic_ids == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": new_display_unit, + "has_mean": True, + "mean_type": StatisticMeanType.ARITHMETIC, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": new_unit, + "unit_class": new_unit_class, + } + ] + + assert statistics_during_period( + hass, + now, + period="5minute", + statistic_ids={"sensor.test"}, + units={"power": "W"}, + ) == { + "sensor.test": [ + { + "end": (now + timedelta(minutes=5)).timestamp(), + "last_reset": None, + "max": 10.0, + "mean": 10.0, + "min": 10.0, + "start": now.timestamp(), + } + ], + } + + +@pytest.mark.parametrize( + ( + "requested_new_unit", + "update_statistics_extra", + "error_message", + ), + [ + ("dogs", {"new_unit_class": "cats"}, "Unsupported unit_class: 'cats'"), + ( + "dogs", + {"new_unit_class": "power"}, + "Unsupported unit_of_measurement 'dogs' for unit_class 'power'", + ), + ], +) +@pytest.mark.usefixtures("recorder_mock") +async def test_update_statistics_metadata_error( + hass: HomeAssistant, + requested_new_unit, + update_statistics_extra, + error_message, +) -> None: + """Test removing statistics.""" + now = get_start_time(dt_util.utcnow()) + + units = METRIC_SYSTEM + attributes = POWER_SENSOR_KW_ATTRIBUTES | {"device_class": None} + state = 10 + + hass.config.units = units + await async_setup_component(hass, "sensor", {}) + await async_recorder_block_till_done(hass) + hass.states.async_set( + "sensor.test", state, attributes=attributes, timestamp=now.timestamp() + ) + await async_wait_recording_done(hass) + + do_adhoc_statistics(hass, period="hourly", start=now) + await async_recorder_block_till_done(hass) + + statistic_ids = await async_list_statistic_ids(hass) + assert statistic_ids == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": "kW", + "has_mean": True, + "mean_type": StatisticMeanType.ARITHMETIC, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": "kW", + "unit_class": "power", + } + ] + + with pytest.raises(HomeAssistantError, match=error_message): + async_update_statistics_metadata( + hass, + "sensor.test", + new_unit_of_measurement=requested_new_unit, + **update_statistics_extra, + ) + await async_recorder_block_till_done(hass) + + statistic_ids = await async_list_statistic_ids(hass) + assert statistic_ids == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": "kW", + "has_mean": True, + "mean_type": StatisticMeanType.ARITHMETIC, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": "kW", + "unit_class": "power", + } + ] + + assert statistics_during_period( + hass, + now, + period="5minute", + statistic_ids={"sensor.test"}, + units={"power": "W"}, + ) == { + "sensor.test": [ + { + "end": (now + timedelta(minutes=5)).timestamp(), + "last_reset": None, + "max": 10000.0, + "mean": 10000.0, + "min": 10000.0, + "start": now.timestamp(), + } + ], + } + + @pytest.mark.usefixtures("multiple_start_time_chunk_sizes") @pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"]) @pytest.mark.freeze_time("2022-10-01 00:00:00+00:00") @@ -1337,6 +1569,7 @@ async def test_daily_statistics_sum( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1518,6 +1751,7 @@ async def test_multiple_daily_statistics_sum( "name": "Total imported energy 1", "source": "test", "statistic_id": "test:total_energy_import2", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_metadata2 = { @@ -1526,6 +1760,7 @@ async def test_multiple_daily_statistics_sum( "name": "Total imported energy 2", "source": "test", "statistic_id": "test:total_energy_import1", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1716,6 +1951,7 @@ async def test_weekly_statistics_mean( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1862,6 +2098,7 @@ async def test_weekly_statistics_sum( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -2043,6 +2280,7 @@ async def test_monthly_statistics_sum( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -2372,6 +2610,7 @@ async def test_change( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -2708,6 +2947,7 @@ async def test_change_multiple( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.total_energy_import1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_metadata2 = { @@ -2716,6 +2956,7 @@ async def test_change_multiple( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.total_energy_import2", + "unit_class": "energy", "unit_of_measurement": "kWh", } async_import_statistics(hass, external_metadata1, external_statistics) @@ -3097,6 +3338,7 @@ async def test_change_with_none( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -3651,6 +3893,7 @@ async def test_get_statistics_service( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.total_energy_import1", + "unit_class": "energy", "unit_of_measurement": "kWh", } external_metadata2 = { @@ -3659,6 +3902,7 @@ async def test_get_statistics_service( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.total_energy_import2", + "unit_class": "energy", "unit_of_measurement": "kWh", } async_import_statistics(hass, external_metadata1, external_statistics) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index aa302548517..14787301d3e 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -324,6 +324,7 @@ async def test_statistic_during_period( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.test", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -772,6 +773,7 @@ async def test_statistic_during_period_circular_mean( "name": "Wind direction", "source": "recorder", "statistic_id": "sensor.test", + "unit_class": None, "unit_of_measurement": DEGREE, } @@ -1098,6 +1100,7 @@ async def test_statistic_during_period_hole( "name": "Total imported energy", "source": "recorder", "statistic_id": "sensor.test", + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -1248,6 +1251,7 @@ async def test_statistic_during_period_hole_circular_mean( "name": "Wind direction", "source": "recorder", "statistic_id": "sensor.test", + "unit_class": None, "unit_of_measurement": DEGREE, } @@ -1441,6 +1445,7 @@ async def test_statistic_during_period_partial_overlap( "name": "Total imported energy overlapping", "source": "recorder", "statistic_id": statId, + "unit_class": "energy", "unit_of_measurement": "kWh", } @@ -2729,13 +2734,30 @@ async def test_clear_statistics_time_out(hass_ws_client: WebSocketGenerator) -> @pytest.mark.parametrize( - ("new_unit", "new_unit_class", "new_display_unit"), - [("dogs", None, "dogs"), (None, "unitless", None), ("W", "power", "kW")], + ( + "requested_new_unit", + "websocket_command_extra", + "new_unit", + "new_unit_class", + "new_display_unit", + ), + [ + ("dogs", {}, "dogs", None, "dogs"), + ("dogs", {"unit_class": None}, "dogs", None, "dogs"), + (None, {}, None, "unitless", None), + (None, {"unit_class": "unitless"}, None, "unitless", None), + ("W", {}, "W", "power", "kW"), + ("W", {"unit_class": "power"}, "W", "power", "kW"), + # Note: Display unit is guessed even if unit_class is None + ("W", {"unit_class": None}, "W", None, "kW"), + ], ) @pytest.mark.usefixtures("recorder_mock") async def test_update_statistics_metadata( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, + requested_new_unit, + websocket_command_extra, new_unit, new_unit_class, new_display_unit, @@ -2781,8 +2803,9 @@ async def test_update_statistics_metadata( { "type": "recorder/update_statistics_metadata", "statistic_id": "sensor.test", - "unit_of_measurement": new_unit, + "unit_of_measurement": requested_new_unit, } + | websocket_command_extra ) response = await client.receive_json() assert response["success"] @@ -2830,6 +2853,124 @@ async def test_update_statistics_metadata( } +@pytest.mark.parametrize( + ( + "requested_new_unit", + "websocket_command_extra", + "error_message", + ), + [ + ("dogs", {"unit_class": "cats"}, "Unsupported unit_class: 'cats'"), + ( + "dogs", + {"unit_class": "power"}, + "Unsupported unit_of_measurement 'dogs' for unit_class 'power'", + ), + ], +) +@pytest.mark.usefixtures("recorder_mock") +async def test_update_statistics_metadata_error( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + requested_new_unit, + websocket_command_extra, + error_message, +) -> None: + """Test removing statistics.""" + now = get_start_time(dt_util.utcnow()) + + units = METRIC_SYSTEM + attributes = POWER_SENSOR_KW_ATTRIBUTES | {"device_class": None} + state = 10 + + hass.config.units = units + await async_setup_component(hass, "sensor", {}) + await async_recorder_block_till_done(hass) + hass.states.async_set( + "sensor.test", state, attributes=attributes, timestamp=now.timestamp() + ) + await async_wait_recording_done(hass) + + do_adhoc_statistics(hass, period="hourly", start=now) + await async_recorder_block_till_done(hass) + + client = await hass_ws_client() + + await client.send_json_auto_id({"type": "recorder/list_statistic_ids"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": "kW", + "has_mean": True, + "mean_type": StatisticMeanType.ARITHMETIC, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": "kW", + "unit_class": "power", + } + ] + + await client.send_json_auto_id( + { + "type": "recorder/update_statistics_metadata", + "statistic_id": "sensor.test", + "unit_of_measurement": requested_new_unit, + } + | websocket_command_extra + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"] == { + "code": "home_assistant_error", + "message": error_message, + } + await async_recorder_block_till_done(hass) + + await client.send_json_auto_id({"type": "recorder/list_statistic_ids"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": "kW", + "has_mean": True, + "mean_type": StatisticMeanType.ARITHMETIC, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": "kW", + "unit_class": "power", + } + ] + + await client.send_json_auto_id( + { + "type": "recorder/statistics_during_period", + "start_time": now.isoformat(), + "statistic_ids": ["sensor.test"], + "period": "5minute", + "units": {"power": "W"}, + } + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == { + "sensor.test": [ + { + "end": int((now + timedelta(minutes=5)).timestamp() * 1000), + "last_reset": None, + "max": 10000.0, + "mean": 10000.0, + "min": 10000.0, + "start": int(now.timestamp() * 1000), + } + ], + } + + @pytest.mark.usefixtures("recorder_mock") async def test_update_statistics_metadata_time_out( hass_ws_client: WebSocketGenerator, @@ -2845,6 +2986,7 @@ async def test_update_statistics_metadata_time_out( { "type": "recorder/update_statistics_metadata", "statistic_id": "sensor.test", + "unit_class": None, "unit_of_measurement": "dogs", } ) @@ -3115,6 +3257,27 @@ async def test_change_statistics_unit_errors( await assert_statistic_ids(expected_statistic_ids) await assert_statistics(expected_statistics) + # Try changing an unknown statistic_id + await client.send_json_auto_id( + { + "type": "recorder/change_statistics_unit", + "statistic_id": "sensor.unknown", + "old_unit_of_measurement": "W", + "new_unit_of_measurement": "kW", + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"] == { + "code": "home_assistant_error", + "message": "No metadata found for sensor.unknown", + } + + await async_recorder_block_till_done(hass) + + await assert_statistic_ids(expected_statistic_ids) + await assert_statistics(expected_statistics) + @pytest.mark.usefixtures("recorder_mock") async def test_recorder_info( @@ -3392,6 +3555,7 @@ async def test_get_statistics_metadata( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_gas", + "unit_class": unit_class, "unit_of_measurement": unit, } @@ -3483,6 +3647,21 @@ async def test_get_statistics_metadata( ] +@pytest.mark.parametrize( + ("external_metadata_extra", "unit_1", "unit_2", "unit_3", "expected_unit_class"), + [ + ({}, "kWh", "kWh", "kWh", "energy"), + ({"unit_class": "energy"}, "kWh", "kWh", "kWh", "energy"), + ({}, "cats", "cats", "cats", None), + ({"unit_class": None}, "cats", "cats", "cats", None), + # Note: The import API does not unit convert and does not block changing unit, + # we may want to address that + ({}, "kWh", "Wh", "MWh", "energy"), + ({"unit_class": "energy"}, "kWh", "Wh", "MWh", "energy"), + ({}, "cats", "dogs", "horses", None), + ({"unit_class": None}, "cats", "dogs", "horses", None), + ], +) @pytest.mark.parametrize( ("source", "statistic_id"), [ @@ -3495,8 +3674,13 @@ async def test_import_statistics( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, caplog: pytest.LogCaptureFixture, - source, - statistic_id, + external_metadata_extra: dict[str, str], + unit_1: str, + unit_2: str, + unit_3: str, + expected_unit_class: str | None, + source: str, + statistic_id: str, ) -> None: """Test importing statistics.""" client = await hass_ws_client() @@ -3527,8 +3711,8 @@ async def test_import_statistics( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, - "unit_of_measurement": "kWh", - } + "unit_of_measurement": unit_1, + } | external_metadata_extra await client.send_json_auto_id( { @@ -3566,15 +3750,15 @@ async def test_import_statistics( statistic_ids = list_statistic_ids(hass) assert statistic_ids == [ { - "display_unit_of_measurement": "kWh", + "display_unit_of_measurement": unit_1, "has_mean": False, "mean_type": StatisticMeanType.NONE, "has_sum": True, "statistic_id": statistic_id, "name": "Total imported energy", "source": source, - "statistics_unit_of_measurement": "kWh", - "unit_class": "energy", + "statistics_unit_of_measurement": unit_1, + "unit_class": expected_unit_class, } ] metadata = get_metadata(hass, statistic_ids={statistic_id}) @@ -3588,7 +3772,8 @@ async def test_import_statistics( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, - "unit_of_measurement": "kWh", + "unit_class": expected_unit_class, + "unit_of_measurement": unit_1, }, ) } @@ -3622,7 +3807,7 @@ async def test_import_statistics( await client.send_json_auto_id( { "type": "recorder/import_statistics", - "metadata": imported_metadata, + "metadata": imported_metadata | {"unit_of_measurement": unit_2}, "stats": [external_statistics], } ) @@ -3652,6 +3837,36 @@ async def test_import_statistics( }, ] } + statistic_ids = list_statistic_ids(hass) + assert statistic_ids == [ + { + "display_unit_of_measurement": unit_2, + "has_mean": False, + "mean_type": StatisticMeanType.NONE, + "has_sum": True, + "statistic_id": statistic_id, + "name": "Total imported energy", + "source": source, + "statistics_unit_of_measurement": unit_2, + "unit_class": expected_unit_class, + } + ] + metadata = get_metadata(hass, statistic_ids={statistic_id}) + assert metadata == { + statistic_id: ( + 1, + { + "has_mean": False, + "mean_type": StatisticMeanType.NONE, + "has_sum": True, + "name": "Total imported energy", + "source": source, + "statistic_id": statistic_id, + "unit_class": expected_unit_class, + "unit_of_measurement": unit_2, + }, + ) + } # Update the previously inserted statistics external_statistics = { @@ -3667,7 +3882,7 @@ async def test_import_statistics( await client.send_json_auto_id( { "type": "recorder/import_statistics", - "metadata": imported_metadata, + "metadata": imported_metadata | {"unit_of_measurement": unit_3}, "stats": [external_statistics], } ) @@ -3697,8 +3912,140 @@ async def test_import_statistics( }, ] } + statistic_ids = list_statistic_ids(hass) + assert statistic_ids == [ + { + "display_unit_of_measurement": unit_3, + "has_mean": False, + "mean_type": StatisticMeanType.NONE, + "has_sum": True, + "statistic_id": statistic_id, + "name": "Total imported energy", + "source": source, + "statistics_unit_of_measurement": unit_3, + "unit_class": expected_unit_class, + } + ] + metadata = get_metadata(hass, statistic_ids={statistic_id}) + assert metadata == { + statistic_id: ( + 1, + { + "has_mean": False, + "mean_type": StatisticMeanType.NONE, + "has_sum": True, + "name": "Total imported energy", + "source": source, + "statistic_id": statistic_id, + "unit_class": expected_unit_class, + "unit_of_measurement": unit_3, + }, + ) + } +@pytest.mark.parametrize( + ("unit_class", "unit", "error_message"), + [ + ("dogs", "cats", "Unsupported unit_class: 'dogs'"), + ( + "energy", + "cats", + "Unsupported unit_of_measurement 'cats' for unit_class 'energy'", + ), + ], +) +@pytest.mark.parametrize( + ("source", "statistic_id"), + [ + ("test", "test:total_energy_import"), + ("recorder", "sensor.total_energy_import"), + ], +) +async def test_import_statistics_with_error( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + caplog: pytest.LogCaptureFixture, + unit_class: str, + unit: str, + error_message: str, + source, + statistic_id, +) -> None: + """Test importing statistics.""" + client = await hass_ws_client() + + assert "Compiling statistics for" not in caplog.text + assert "Statistics already compiled" not in caplog.text + + zero = dt_util.utcnow() + period1 = zero.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1) + period2 = zero.replace(minute=0, second=0, microsecond=0) + timedelta(hours=2) + + imported_statistics1 = { + "start": period1.isoformat(), + "last_reset": None, + "state": 0, + "sum": 2, + } + imported_statistics2 = { + "start": period2.isoformat(), + "last_reset": None, + "state": 1, + "sum": 3, + } + + imported_metadata = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": source, + "statistic_id": statistic_id, + "unit_class": unit_class, + "unit_of_measurement": unit, + } + + await client.send_json_auto_id( + { + "type": "recorder/import_statistics", + "metadata": imported_metadata, + "stats": [imported_statistics1, imported_statistics2], + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"] == { + "code": "home_assistant_error", + "message": error_message, + } + + await async_wait_recording_done(hass) + stats = statistics_during_period( + hass, zero, period="hour", statistic_ids={statistic_id} + ) + assert stats == {} + statistic_ids = list_statistic_ids(hass) + assert statistic_ids == [] + metadata = get_metadata(hass, statistic_ids={statistic_id}) + assert metadata == {} + last_stats = get_last_statistics( + hass, + 1, + statistic_id, + True, + {"last_reset", "max", "mean", "min", "state", "sum"}, + ) + assert last_stats == {} + + +@pytest.mark.parametrize( + ("external_metadata_extra"), + [ + {}, + {"unit_class": "energy"}, + ], +) @pytest.mark.parametrize( ("source", "statistic_id"), [ @@ -3711,6 +4058,7 @@ async def test_adjust_sum_statistics_energy( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, caplog: pytest.LogCaptureFixture, + external_metadata_extra: dict[str, str], source, statistic_id, ) -> None: @@ -3744,7 +4092,7 @@ async def test_adjust_sum_statistics_energy( "source": source, "statistic_id": statistic_id, "unit_of_measurement": "kWh", - } + } | external_metadata_extra await client.send_json_auto_id( { @@ -3808,6 +4156,7 @@ async def test_adjust_sum_statistics_energy( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, + "unit_class": "energy", "unit_of_measurement": "kWh", }, ) @@ -3894,6 +4243,13 @@ async def test_adjust_sum_statistics_energy( } +@pytest.mark.parametrize( + ("external_metadata_extra"), + [ + {}, + {"unit_class": "volume"}, + ], +) @pytest.mark.parametrize( ("source", "statistic_id"), [ @@ -3906,6 +4262,7 @@ async def test_adjust_sum_statistics_gas( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, caplog: pytest.LogCaptureFixture, + external_metadata_extra: dict[str, str], source, statistic_id, ) -> None: @@ -3939,7 +4296,7 @@ async def test_adjust_sum_statistics_gas( "source": source, "statistic_id": statistic_id, "unit_of_measurement": "m³", - } + } | external_metadata_extra await client.send_json_auto_id( { @@ -4003,6 +4360,7 @@ async def test_adjust_sum_statistics_gas( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, + "unit_class": "volume", "unit_of_measurement": "m³", }, ) @@ -4150,6 +4508,7 @@ async def test_adjust_sum_statistics_errors( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, + "unit_class": unit_class, "unit_of_measurement": statistic_unit, } @@ -4216,6 +4575,7 @@ async def test_adjust_sum_statistics_errors( "name": "Total imported energy", "source": source, "statistic_id": statistic_id, + "unit_class": unit_class, "unit_of_measurement": state_unit, }, ) @@ -4312,6 +4672,7 @@ async def test_import_statistics_with_last_reset( "name": "Total imported energy", "source": "test", "statistic_id": "test:total_energy_import", + "unit_class": "energy", "unit_of_measurement": "kWh", } diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index 6afce0d3eb5..5dadc5bd4ed 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -241,10 +241,25 @@ async def assert_validation_result( ), [ (None, "%", "%", "%", "unitless", 13.050847, -10, 30), + (None, "ppm", "ppm", "ppm", "unitless", 13.050847, -10, 30), + (None, "g/m³", "g/m³", "g/m³", "concentration", 13.050847, -10, 30), + (None, "mg/m³", "mg/m³", "mg/m³", "concentration", 13.050847, -10, 30), ("area", "m²", "m²", "m²", "area", 13.050847, -10, 30), ("area", "mi²", "mi²", "mi²", "area", 13.050847, -10, 30), ("battery", "%", "%", "%", "unitless", 13.050847, -10, 30), ("battery", None, None, None, "unitless", 13.050847, -10, 30), + # We can't yet convert carbon_monoxide + ( + "carbon_monoxide", + "mg/m³", + "mg/m³", + "mg/m³", + "concentration", + 13.050847, + -10, + 30, + ), + ("carbon_monoxide", "ppm", "ppm", "ppm", "unitless", 13.050847, -10, 30), ("distance", "m", "m", "m", "distance", 13.050847, -10, 30), ("distance", "mi", "mi", "mi", "distance", 13.050847, -10, 30), ("humidity", "%", "%", "%", "unitless", 13.050847, -10, 30), @@ -3261,6 +3276,9 @@ async def test_list_statistic_ids_unsupported( (None, "ft³", "ft3", "volume", 13.050847, -10, 30), (None, "ft³/min", "ft³/m", "volume_flow_rate", 13.050847, -10, 30), (None, "m³", "m3", "volume", 13.050847, -10, 30), + # Can't yet convert carbon_monoxide + ("carbon_monoxide", "ppm", "mg/m³", "unitless", 13.050847, -10, 30), + ("carbon_monoxide", "mg/m³", "ppm", "concentration", 13.050847, -10, 30), ], ) async def test_compile_hourly_statistics_changing_units_1( @@ -3589,17 +3607,30 @@ async def test_compile_hourly_statistics_changing_units_3( @pytest.mark.parametrize( - ("state_unit_1", "state_unit_2", "unit_class", "mean", "min", "max", "factor"), + ( + "device_class", + "state_unit_1", + "state_unit_2", + "unit_class", + "mean", + "min", + "max", + "factor", + ), [ - (None, "%", "unitless", 13.050847, -10, 30, 100), - ("%", None, "unitless", 13.050847, -10, 30, 0.01), - ("W", "kW", "power", 13.050847, -10, 30, 0.001), - ("kW", "W", "power", 13.050847, -10, 30, 1000), + (None, None, "%", "unitless", 13.050847, -10, 30, 100), + (None, None, "ppm", "unitless", 13.050847, -10, 30, 1000000), + (None, "g/m³", "mg/m³", "concentration", 13.050847, -10, 30, 1000), + (None, "mg/m³", "g/m³", "concentration", 13.050847, -10, 30, 0.001), + (None, "%", None, "unitless", 13.050847, -10, 30, 0.01), + (None, "W", "kW", "power", 13.050847, -10, 30, 0.001), + (None, "kW", "W", "power", 13.050847, -10, 30, 1000), ], ) async def test_compile_hourly_statistics_convert_units_1( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, + device_class, state_unit_1, state_unit_2, unit_class, @@ -3617,7 +3648,7 @@ async def test_compile_hourly_statistics_convert_units_1( # Wait for the sensor recorder platform to be added await async_recorder_block_till_done(hass) attributes = { - "device_class": None, + "device_class": device_class, "state_class": "measurement", "unit_of_measurement": state_unit_1, } @@ -4441,6 +4472,7 @@ async def test_compile_hourly_statistics_changing_state_class( "name": None, "source": "recorder", "statistic_id": "sensor.test1", + "unit_class": unit_class, "unit_of_measurement": None, }, ) @@ -4485,6 +4517,7 @@ async def test_compile_hourly_statistics_changing_state_class( "name": None, "source": "recorder", "statistic_id": "sensor.test1", + "unit_class": unit_class, "unit_of_measurement": None, }, ) @@ -5936,6 +5969,7 @@ async def test_validate_statistics_other_domain( "name": None, "source": RECORDER_DOMAIN, "statistic_id": "number.test", + "unit_class": None, "unit_of_measurement": None, } statistics: StatisticData = {