From cd6697615f74aa7bf6d41e3ec4d7d9bad4ffa2e4 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 21 Sep 2022 18:08:53 +0200 Subject: [PATCH] Validate units when importing statistics (#78891) --- .../components/recorder/statistics.py | 40 ++++++++++++++++++- tests/components/demo/test_init.py | 1 + tests/components/energy/test_websocket_api.py | 9 +++++ tests/components/recorder/test_statistics.py | 23 +++++++++++ .../components/recorder/test_websocket_api.py | 28 ++++++++++++- 5 files changed, 98 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 1fbd8f192ef..5f780b5f0c7 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -195,6 +195,17 @@ STATISTIC_UNIT_TO_UNIT_CLASS: dict[str | None, str] = { VOLUME_CUBIC_METERS: "volume", } +STATISTIC_UNIT_TO_VALID_UNITS: dict[str | None, Iterable[str | None]] = { + ENERGY_KILO_WATT_HOUR: [ + ENERGY_KILO_WATT_HOUR, + ENERGY_MEGA_WATT_HOUR, + ENERGY_WATT_HOUR, + ], + POWER_WATT: power_util.VALID_UNITS, + PRESSURE_PA: pressure_util.VALID_UNITS, + TEMP_CELSIUS: temperature_util.VALID_UNITS, + VOLUME_CUBIC_METERS: volume_util.VALID_UNITS, +} # Convert energy power, pressure, temperature and volume statistics from the # normalized unit used for statistics to the unit configured by the user @@ -238,8 +249,17 @@ def _get_statistic_to_display_unit_converter( ) is None: return no_conversion + display_unit: str | None unit_class = STATISTIC_UNIT_TO_UNIT_CLASS[statistic_unit] - display_unit = requested_units.get(unit_class) if requested_units else state_unit + if requested_units and unit_class in requested_units: + display_unit = requested_units[unit_class] + else: + display_unit = state_unit + + if display_unit not in STATISTIC_UNIT_TO_VALID_UNITS[statistic_unit]: + # Guard against invalid state unit in the DB + return no_conversion + return partial(convert_fn, display_unit) @@ -1503,6 +1523,16 @@ def _async_import_statistics( get_instance(hass).async_import_statistics(metadata, statistics) +def _validate_units(statistics_unit: str | None, state_unit: str | None) -> None: + """Raise if the statistics unit and state unit are not compatible.""" + if statistics_unit == state_unit: + return + if (valid_units := STATISTIC_UNIT_TO_VALID_UNITS.get(statistics_unit)) is None: + raise HomeAssistantError(f"Invalid units {statistics_unit},{state_unit}") + if state_unit not in valid_units: + raise HomeAssistantError(f"Invalid units {statistics_unit},{state_unit}") + + @callback def async_import_statistics( hass: HomeAssistant, @@ -1520,6 +1550,10 @@ def async_import_statistics( if not metadata["source"] or metadata["source"] != DOMAIN: raise HomeAssistantError("Invalid source") + _validate_units( + metadata["unit_of_measurement"], metadata["state_unit_of_measurement"] + ) + _async_import_statistics(hass, metadata, statistics) @@ -1542,6 +1576,10 @@ def async_add_external_statistics( if not metadata["source"] or metadata["source"] != domain: raise HomeAssistantError("Invalid source") + _validate_units( + metadata["unit_of_measurement"], metadata["state_unit_of_measurement"] + ) + _async_import_statistics(hass, metadata, statistics) diff --git a/tests/components/demo/test_init.py b/tests/components/demo/test_init.py index 79da28d8abd..934321a0ed8 100644 --- a/tests/components/demo/test_init.py +++ b/tests/components/demo/test_init.py @@ -96,6 +96,7 @@ async def test_demo_statistics_growth(hass, recorder_mock): metadata = { "source": DOMAIN, "name": "Energy consumption 1", + "state_unit_of_measurement": "m³", "statistic_id": statistic_id, "unit_of_measurement": "m³", "has_mean": False, diff --git a/tests/components/energy/test_websocket_api.py b/tests/components/energy/test_websocket_api.py index 1e9f89ac726..8adc091305c 100644 --- a/tests/components/energy/test_websocket_api.py +++ b/tests/components/energy/test_websocket_api.py @@ -343,6 +343,7 @@ async def test_fossil_energy_consumption_no_co2(hass, hass_ws_client, recorder_m "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_1", "unit_of_measurement": "kWh", } @@ -377,6 +378,7 @@ async def test_fossil_energy_consumption_no_co2(hass, hass_ws_client, recorder_m "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_2", "unit_of_measurement": "kWh", } @@ -504,6 +506,7 @@ async def test_fossil_energy_consumption_hole(hass, hass_ws_client, recorder_moc "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_1", "unit_of_measurement": "kWh", } @@ -538,6 +541,7 @@ async def test_fossil_energy_consumption_hole(hass, hass_ws_client, recorder_moc "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_2", "unit_of_measurement": "kWh", } @@ -663,6 +667,7 @@ async def test_fossil_energy_consumption_no_data(hass, hass_ws_client, recorder_ "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_1", "unit_of_measurement": "kWh", } @@ -697,6 +702,7 @@ async def test_fossil_energy_consumption_no_data(hass, hass_ws_client, recorder_ "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_2", "unit_of_measurement": "kWh", } @@ -813,6 +819,7 @@ async def test_fossil_energy_consumption(hass, hass_ws_client, recorder_mock): "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_1", "unit_of_measurement": "kWh", } @@ -847,6 +854,7 @@ async def test_fossil_energy_consumption(hass, hass_ws_client, recorder_mock): "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_2", "unit_of_measurement": "kWh", } @@ -877,6 +885,7 @@ async def test_fossil_energy_consumption(hass, hass_ws_client, recorder_mock): "has_sum": False, "name": "Fossil percentage", "source": "test", + "state_unit_of_measurement": "%", "statistic_id": "test:fossil_percentage", "unit_of_measurement": "%", } diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index f2de32a443c..4fc98333bf4 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -741,6 +741,7 @@ def test_external_statistics_errors(hass_recorder, caplog): "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import", "unit_of_measurement": "kWh", } @@ -804,6 +805,16 @@ def test_external_statistics_errors(hass_recorder, caplog): assert list_statistic_ids(hass) == [] assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + # Attempt to insert statistics with an invalid unit combination + external_metadata = {**_external_metadata, "state_unit_of_measurement": "cats"} + external_statistics = {**_external_statistics} + with pytest.raises(HomeAssistantError): + async_add_external_statistics(hass, external_metadata, (external_statistics,)) + wait_recording_done(hass) + assert statistics_during_period(hass, zero, period="hour") == {} + assert list_statistic_ids(hass) == [] + assert get_metadata(hass, statistic_ids=("test:total_energy_import",)) == {} + def test_import_statistics_errors(hass_recorder, caplog): """Test validation of imported statistics.""" @@ -828,6 +839,7 @@ def test_import_statistics_errors(hass_recorder, caplog): "has_sum": True, "name": "Total imported energy", "source": "recorder", + "state_unit_of_measurement": "kWh", "statistic_id": "sensor.total_energy_import", "unit_of_measurement": "kWh", } @@ -891,6 +903,16 @@ def test_import_statistics_errors(hass_recorder, caplog): assert list_statistic_ids(hass) == [] assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + # Attempt to insert statistics with an invalid unit combination + external_metadata = {**_external_metadata, "state_unit_of_measurement": "cats"} + external_statistics = {**_external_statistics} + with pytest.raises(HomeAssistantError): + async_import_statistics(hass, external_metadata, (external_statistics,)) + wait_recording_done(hass) + assert statistics_during_period(hass, zero, period="hour") == {} + assert list_statistic_ids(hass) == [] + assert get_metadata(hass, statistic_ids=("sensor.total_energy_import",)) == {} + @pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"]) @pytest.mark.freeze_time("2021-08-01 00:00:00+00:00") @@ -940,6 +962,7 @@ def test_monthly_statistics(hass_recorder, caplog, timezone): "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import", "unit_of_measurement": "kWh", } diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 6f8b2be7d58..e8214f0209e 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -1100,6 +1100,7 @@ async def test_get_statistics_metadata( "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": unit, "statistic_id": "test:total_gas", "unit_of_measurement": unit, } @@ -1107,6 +1108,29 @@ async def test_get_statistics_metadata( async_add_external_statistics( hass, external_energy_metadata_1, external_energy_statistics_1 ) + await async_wait_recording_done(hass) + + await client.send_json( + { + "id": 2, + "type": "recorder/get_statistics_metadata", + "statistic_ids": ["test:total_gas"], + } + ) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + { + "statistic_id": "test:total_gas", + "display_unit_of_measurement": unit, + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistics_unit_of_measurement": unit, + "unit_class": unit_class, + } + ] hass.states.async_set("sensor.test", 10, attributes=attributes) await async_wait_recording_done(hass) @@ -1116,7 +1140,7 @@ async def test_get_statistics_metadata( await client.send_json( { - "id": 2, + "id": 3, "type": "recorder/get_statistics_metadata", "statistic_ids": ["sensor.test"], } @@ -1144,7 +1168,7 @@ async def test_get_statistics_metadata( await client.send_json( { - "id": 3, + "id": 4, "type": "recorder/get_statistics_metadata", "statistic_ids": ["sensor.test"], }