diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 7351657f64c..b473dead17b 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -329,6 +329,13 @@ class ClearStatisticsTask(NamedTuple): statistic_ids: list[str] +class UpdateStatisticsMetadataTask(NamedTuple): + """Object to store statistics_id and unit for update of statistics metadata.""" + + statistic_id: str + unit_of_measurement: str | None + + class PurgeTask(NamedTuple): """Object to store information about purge task.""" @@ -581,6 +588,11 @@ class Recorder(threading.Thread): """Clear statistics for a list of statistic_ids.""" self.queue.put(ClearStatisticsTask(statistic_ids)) + @callback + def async_update_statistics_metadata(self, statistic_id, unit_of_measurement): + """Update statistics metadata for a statistic_id.""" + self.queue.put(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement)) + @callback def _async_setup_periodic_tasks(self): """Prepare periodic tasks.""" @@ -777,6 +789,11 @@ class Recorder(threading.Thread): if isinstance(event, ClearStatisticsTask): statistics.clear_statistics(self, event.statistic_ids) return + if isinstance(event, UpdateStatisticsMetadataTask): + statistics.update_statistics_metadata( + self, event.statistic_id, event.unit_of_measurement + ) + return if isinstance(event, WaitTask): self._queue_watch.set() return diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 1b864992010..2b4775e2412 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -466,6 +466,16 @@ def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: ).delete(synchronize_session=False) +def update_statistics_metadata( + instance: Recorder, statistic_id: str, unit_of_measurement: str | None +) -> None: + """Update statistics metadata for a statistic_id.""" + with session_scope(session=instance.get_session()) as session: # type: ignore + session.query(StatisticsMeta).filter( + StatisticsMeta.statistic_id == statistic_id + ).update({StatisticsMeta.unit_of_measurement: unit_of_measurement}) + + def list_statistic_ids( hass: HomeAssistant, statistic_type: Literal["mean"] | Literal["sum"] | None = None, diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 6d90f48eb76..ba77692fe8e 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -1,4 +1,6 @@ """The Energy websocket API.""" +from __future__ import annotations + import voluptuous as vol from homeassistant.components import websocket_api @@ -13,6 +15,7 @@ def async_setup(hass: HomeAssistant) -> None: """Set up the recorder websocket API.""" websocket_api.async_register_command(hass, ws_validate_statistics) websocket_api.async_register_command(hass, ws_clear_statistics) + websocket_api.async_register_command(hass, ws_update_statistics_metadata) @websocket_api.websocket_command( @@ -50,3 +53,22 @@ def ws_clear_statistics( """ hass.data[DATA_INSTANCE].async_clear_statistics(msg["statistic_ids"]) connection.send_result(msg["id"]) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "recorder/update_statistics_metadata", + vol.Required("statistic_id"): str, + vol.Required("unit_of_measurement"): vol.Any(str, None), + } +) +@callback +def ws_update_statistics_metadata( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """Update statistics metadata for a statistic_id.""" + hass.data[DATA_INSTANCE].async_update_statistics_metadata( + msg["statistic_id"], msg["unit_of_measurement"] + ) + connection.send_result(msg["id"]) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index d9e546f6894..9e856fc6b85 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -377,3 +377,55 @@ async def test_clear_statistics(hass, hass_ws_client): response = await client.receive_json() assert response["success"] assert response["result"] == {"sensor.test2": expected_response["sensor.test2"]} + + +@pytest.mark.parametrize("new_unit", ["dogs", None]) +async def test_update_statistics_metadata(hass, hass_ws_client, new_unit): + """Test removing statistics.""" + now = dt_util.utcnow() + + units = METRIC_SYSTEM + attributes = POWER_SENSOR_ATTRIBUTES + state = 10 + + hass.config.units = units + await hass.async_add_executor_job(init_recorder_component, hass) + await async_setup_component(hass, "history", {}) + await async_setup_component(hass, "sensor", {}) + await hass.async_add_executor_job(hass.data[DATA_INSTANCE].block_till_done) + hass.states.async_set("sensor.test", state, attributes=attributes) + await hass.async_block_till_done() + + await hass.async_add_executor_job(trigger_db_commit, hass) + await hass.async_block_till_done() + + hass.data[DATA_INSTANCE].do_adhoc_statistics(period="hourly", start=now) + await hass.async_add_executor_job(hass.data[DATA_INSTANCE].block_till_done) + + client = await hass_ws_client() + + await client.send_json({"id": 1, "type": "history/list_statistic_ids"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + {"statistic_id": "sensor.test", "unit_of_measurement": "W"} + ] + + await client.send_json( + { + "id": 2, + "type": "recorder/update_statistics_metadata", + "statistic_id": "sensor.test", + "unit_of_measurement": new_unit, + } + ) + response = await client.receive_json() + assert response["success"] + await hass.async_add_executor_job(hass.data[DATA_INSTANCE].block_till_done) + + await client.send_json({"id": 3, "type": "history/list_statistic_ids"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + {"statistic_id": "sensor.test", "unit_of_measurement": new_unit} + ]