diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 2583b2f53b6..b6af3d60d73 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -196,6 +196,7 @@ def list_statistic_ids( ) -> list[dict[str, str] | None]: """Return statistic_ids and meta data.""" units = hass.config.units + statistic_ids = {} with session_scope(hass=hass) as session: metadata = _get_metadata(hass, session, None, statistic_type) @@ -203,7 +204,26 @@ def list_statistic_ids( unit = _configured_unit(meta["unit_of_measurement"], units) meta["unit_of_measurement"] = unit - return list(metadata.values()) + statistic_ids = { + meta["statistic_id"]: meta["unit_of_measurement"] + for meta in metadata.values() + } + + for platform in hass.data[DOMAIN].values(): + if not hasattr(platform, "list_statistic_ids"): + continue + platform_statistic_ids = platform.list_statistic_ids(hass, statistic_type) + + for statistic_id, unit in platform_statistic_ids.items(): + unit = _configured_unit(unit, units) + platform_statistic_ids[statistic_id] = unit + + statistic_ids = {**statistic_ids, **platform_statistic_ids} + + return [ + {"statistic_id": _id, "unit_of_measurement": unit} + for _id, unit in statistic_ids.items() + ] def statistics_during_period( diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py index 485674ec728..b5a38cfeec1 100644 --- a/homeassistant/components/sensor/recorder.py +++ b/homeassistant/components/sensor/recorder.py @@ -40,7 +40,7 @@ import homeassistant.util.dt as dt_util import homeassistant.util.pressure as pressure_util import homeassistant.util.temperature as temperature_util -from . import DOMAIN +from . import ATTR_LAST_RESET, DOMAIN _LOGGER = logging.getLogger(__name__) @@ -280,3 +280,36 @@ def compile_statistics( result[entity_id]["stat"] = stat return result + + +def list_statistic_ids(hass: HomeAssistant, statistic_type: str | None = None) -> dict: + """Return statistic_ids and meta data.""" + entities = _get_entities(hass) + + statistic_ids = {} + + for entity_id, device_class in entities: + provided_statistics = DEVICE_CLASS_STATISTICS[device_class] + + if statistic_type is not None and statistic_type not in provided_statistics: + continue + + state = hass.states.get(entity_id) + assert state + + if "sum" in provided_statistics and ATTR_LAST_RESET not in state.attributes: + continue + + native_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + + if device_class not in UNIT_CONVERSIONS: + statistic_ids[entity_id] = native_unit + continue + + if native_unit not in UNIT_CONVERSIONS[device_class]: + continue + + statistics_unit = DEVICE_CLASS_UNITS[device_class] + statistic_ids[entity_id] = statistics_unit + + return statistic_ids diff --git a/tests/components/history/test_init.py b/tests/components/history/test_init.py index df4a59a372c..8d78f80c634 100644 --- a/tests/components/history/test_init.py +++ b/tests/components/history/test_init.py @@ -988,11 +988,6 @@ async def test_list_statistic_ids(hass, hass_ws_client, units, attributes, unit) await async_setup_component(hass, "history", {"history": {}}) await async_setup_component(hass, "sensor", {}) await hass.async_add_executor_job(hass.data[recorder.DATA_INSTANCE].block_till_done) - hass.states.async_set("sensor.test", 10, attributes=attributes) - await hass.async_block_till_done() - - await hass.async_add_executor_job(trigger_db_commit, hass) - await hass.async_block_till_done() client = await hass_ws_client() await client.send_json({"id": 1, "type": "history/list_statistic_ids"}) @@ -1000,8 +995,11 @@ async def test_list_statistic_ids(hass, hass_ws_client, units, attributes, unit) assert response["success"] assert response["result"] == [] - hass.data[recorder.DATA_INSTANCE].do_adhoc_statistics(period="hourly", start=now) - await hass.async_add_executor_job(hass.data[recorder.DATA_INSTANCE].block_till_done) + hass.states.async_set("sensor.test", 10, attributes=attributes) + await hass.async_block_till_done() + + await hass.async_add_executor_job(trigger_db_commit, hass) + await hass.async_block_till_done() await client.send_json({"id": 2, "type": "history/list_statistic_ids"}) response = await client.receive_json() @@ -1010,15 +1008,28 @@ async def test_list_statistic_ids(hass, hass_ws_client, units, attributes, unit) {"statistic_id": "sensor.test", "unit_of_measurement": unit} ] + hass.data[recorder.DATA_INSTANCE].do_adhoc_statistics(period="hourly", start=now) + await hass.async_add_executor_job(hass.data[recorder.DATA_INSTANCE].block_till_done) + # Remove the state, statistics will now be fetched from the database + hass.states.async_remove("sensor.test") + await hass.async_block_till_done() + + await client.send_json({"id": 3, "type": "history/list_statistic_ids"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + {"statistic_id": "sensor.test", "unit_of_measurement": unit} + ] + await client.send_json( - {"id": 3, "type": "history/list_statistic_ids", "statistic_type": "dogs"} + {"id": 4, "type": "history/list_statistic_ids", "statistic_type": "dogs"} ) response = await client.receive_json() assert response["success"] assert response["result"] == [] await client.send_json( - {"id": 4, "type": "history/list_statistic_ids", "statistic_type": "mean"} + {"id": 5, "type": "history/list_statistic_ids", "statistic_type": "mean"} ) response = await client.receive_json() assert response["success"] @@ -1027,7 +1038,7 @@ async def test_list_statistic_ids(hass, hass_ws_client, units, attributes, unit) ] await client.send_json( - {"id": 5, "type": "history/list_statistic_ids", "statistic_type": "sum"} + {"id": 6, "type": "history/list_statistic_ids", "statistic_type": "sum"} ) response = await client.receive_json() assert response["success"] diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index 99ede396381..998cc93e629 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -120,12 +120,6 @@ def test_compile_hourly_statistics_unsupported(hass_recorder, caplog, attributes attributes.pop("state_class") _, _states = record_states(hass, zero, "sensor.test5", attributes) states = {**states, **_states} - attributes["state_class"] = "measurement" - _, _states = record_states(hass, zero, "sensor.test6", attributes) - states = {**states, **_states} - attributes["state_class"] = "unsupported" - _, _states = record_states(hass, zero, "sensor.test7", attributes) - states = {**states, **_states} hist = history.get_significant_states(hass, zero, four) assert dict(states) == dict(hist) @@ -626,6 +620,79 @@ def test_compile_hourly_statistics_fails(hass_recorder, caplog): assert "Error while processing event StatisticsTask" in caplog.text +@pytest.mark.parametrize( + "device_class,unit,native_unit,statistic_type", + [ + ("battery", "%", "%", "mean"), + ("battery", None, None, "mean"), + ("energy", "Wh", "kWh", "sum"), + ("energy", "kWh", "kWh", "sum"), + ("humidity", "%", "%", "mean"), + ("humidity", None, None, "mean"), + ("monetary", "USD", "USD", "sum"), + ("monetary", "None", "None", "sum"), + ("pressure", "Pa", "Pa", "mean"), + ("pressure", "hPa", "Pa", "mean"), + ("pressure", "mbar", "Pa", "mean"), + ("pressure", "inHg", "Pa", "mean"), + ("pressure", "psi", "Pa", "mean"), + ("temperature", "°C", "°C", "mean"), + ("temperature", "°F", "°C", "mean"), + ], +) +def test_list_statistic_ids( + hass_recorder, caplog, device_class, unit, native_unit, statistic_type +): + """Test listing future statistic ids.""" + hass = hass_recorder() + setup_component(hass, "sensor", {}) + attributes = { + "device_class": device_class, + "last_reset": 0, + "state_class": "measurement", + "unit_of_measurement": unit, + } + hass.states.set("sensor.test1", 0, attributes=attributes) + statistic_ids = list_statistic_ids(hass) + assert statistic_ids == [ + {"statistic_id": "sensor.test1", "unit_of_measurement": native_unit} + ] + for stat_type in ["mean", "sum", "dogs"]: + statistic_ids = list_statistic_ids(hass, statistic_type=stat_type) + if statistic_type == stat_type: + assert statistic_ids == [ + {"statistic_id": "sensor.test1", "unit_of_measurement": native_unit} + ] + else: + assert statistic_ids == [] + + +@pytest.mark.parametrize( + "_attributes", + [{**ENERGY_SENSOR_ATTRIBUTES, "last_reset": 0}, TEMPERATURE_SENSOR_ATTRIBUTES], +) +def test_list_statistic_ids_unsupported(hass_recorder, caplog, _attributes): + """Test listing future statistic ids for unsupported sensor.""" + hass = hass_recorder() + setup_component(hass, "sensor", {}) + attributes = dict(_attributes) + hass.states.set("sensor.test1", 0, attributes=attributes) + if "last_reset" in attributes: + attributes.pop("unit_of_measurement") + hass.states.set("last_reset.test2", 0, attributes=attributes) + attributes = dict(_attributes) + if "unit_of_measurement" in attributes: + attributes["unit_of_measurement"] = "invalid" + hass.states.set("sensor.test3", 0, attributes=attributes) + attributes.pop("unit_of_measurement") + hass.states.set("sensor.test4", 0, attributes=attributes) + attributes = dict(_attributes) + attributes["state_class"] = "invalid" + hass.states.set("sensor.test5", 0, attributes=attributes) + attributes.pop("state_class") + hass.states.set("sensor.test6", 0, attributes=attributes) + + def record_states(hass, zero, entity_id, attributes): """Record some test states.