From 89b7bf21088934676d4828179e2fc4aa0ebcccac Mon Sep 17 00:00:00 2001 From: dougiteixeira <31328123+dougiteixeira@users.noreply.github.com> Date: Sat, 22 Jun 2024 10:03:43 -0300 Subject: [PATCH] Add the ability to change the source entity of the Derivative helper (#119754) --- .../components/derivative/__init__.py | 10 ++- .../components/derivative/config_flow.py | 78 +++++++++++++---- homeassistant/components/derivative/sensor.py | 32 ++----- .../components/derivative/test_config_flow.py | 23 +++-- tests/components/derivative/test_init.py | 87 ++++++++++++++++++- 5 files changed, 181 insertions(+), 49 deletions(-) diff --git a/homeassistant/components/derivative/__init__.py b/homeassistant/components/derivative/__init__.py index 2b365e96244..5117663f3c5 100644 --- a/homeassistant/components/derivative/__init__.py +++ b/homeassistant/components/derivative/__init__.py @@ -3,12 +3,20 @@ from __future__ import annotations from homeassistant.config_entries import ConfigEntry -from homeassistant.const import Platform +from homeassistant.const import CONF_SOURCE, Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers.device import ( + async_remove_stale_devices_links_keep_entity_device, +) async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Derivative from a config entry.""" + + async_remove_stale_devices_links_keep_entity_device( + hass, entry.entry_id, entry.options[CONF_SOURCE] + ) + await hass.config_entries.async_forward_entry_setups(entry, (Platform.SENSOR,)) entry.async_on_unload(entry.add_update_listener(config_entry_update_listener)) return True diff --git a/homeassistant/components/derivative/config_flow.py b/homeassistant/components/derivative/config_flow.py index e15741ce9cf..2ef2018eda8 100644 --- a/homeassistant/components/derivative/config_flow.py +++ b/homeassistant/components/derivative/config_flow.py @@ -10,11 +10,19 @@ import voluptuous as vol from homeassistant.components.counter import DOMAIN as COUNTER_DOMAIN from homeassistant.components.input_number import DOMAIN as INPUT_NUMBER_DOMAIN from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN -from homeassistant.const import CONF_NAME, CONF_SOURCE, UnitOfTime +from homeassistant.const import ( + ATTR_UNIT_OF_MEASUREMENT, + CONF_NAME, + CONF_SOURCE, + UnitOfTime, +) +from homeassistant.core import callback from homeassistant.helpers import selector from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, SchemaConfigFlowHandler, SchemaFlowFormStep, + SchemaOptionsFlowHandler, ) from .const import ( @@ -42,8 +50,43 @@ TIME_UNITS = [ UnitOfTime.DAYS, ] -OPTIONS_SCHEMA = vol.Schema( - { +ALLOWED_DOMAINS = [COUNTER_DOMAIN, INPUT_NUMBER_DOMAIN, SENSOR_DOMAIN] + + +@callback +def entity_selector_compatible( + handler: SchemaOptionsFlowHandler, +) -> selector.EntitySelector: + """Return an entity selector which compatible entities.""" + current = handler.hass.states.get(handler.options[CONF_SOURCE]) + unit_of_measurement = ( + current.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if current else None + ) + + entities = [ + ent.entity_id + for ent in handler.hass.states.async_all(ALLOWED_DOMAINS) + if ent.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == unit_of_measurement + and ent.domain in ALLOWED_DOMAINS + ] + + return selector.EntitySelector( + selector.EntitySelectorConfig(include_entities=entities) + ) + + +async def _get_options_dict(handler: SchemaCommonFlowHandler | None) -> dict: + if handler is None or not isinstance( + handler.parent_handler, SchemaOptionsFlowHandler + ): + entity_selector = selector.EntitySelector( + selector.EntitySelectorConfig(domain=ALLOWED_DOMAINS) + ) + else: + entity_selector = entity_selector_compatible(handler.parent_handler) + + return { + vol.Required(CONF_SOURCE): entity_selector, vol.Required(CONF_ROUND_DIGITS, default=2): selector.NumberSelector( selector.NumberSelectorConfig( min=0, @@ -62,25 +105,28 @@ OPTIONS_SCHEMA = vol.Schema( ), ), } -) -CONFIG_SCHEMA = vol.Schema( - { - vol.Required(CONF_NAME): selector.TextSelector(), - vol.Required(CONF_SOURCE): selector.EntitySelector( - selector.EntitySelectorConfig( - domain=[COUNTER_DOMAIN, INPUT_NUMBER_DOMAIN, SENSOR_DOMAIN] - ), - ), - } -).extend(OPTIONS_SCHEMA.schema) + +async def _get_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: + return vol.Schema(await _get_options_dict(handler)) + + +async def _get_config_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: + options = await _get_options_dict(handler) + return vol.Schema( + { + vol.Required(CONF_NAME): selector.TextSelector(), + **options, + } + ) + CONFIG_FLOW = { - "user": SchemaFlowFormStep(CONFIG_SCHEMA), + "user": SchemaFlowFormStep(_get_config_schema), } OPTIONS_FLOW = { - "init": SchemaFlowFormStep(OPTIONS_SCHEMA), + "init": SchemaFlowFormStep(_get_options_schema), } diff --git a/homeassistant/components/derivative/sensor.py b/homeassistant/components/derivative/sensor.py index d5a83035ed5..fd430c6ef4d 100644 --- a/homeassistant/components/derivative/sensor.py +++ b/homeassistant/components/derivative/sensor.py @@ -20,11 +20,8 @@ from homeassistant.const import ( UnitOfTime, ) from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback -from homeassistant.helpers import ( - config_validation as cv, - device_registry as dr, - entity_registry as er, -) +from homeassistant.helpers import config_validation as cv, entity_registry as er +from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import async_track_state_change_event @@ -90,27 +87,10 @@ async def async_setup_entry( registry, config_entry.options[CONF_SOURCE] ) - source_entity = registry.async_get(source_entity_id) - dev_reg = dr.async_get(hass) - # Resolve source entity device - if ( - (source_entity is not None) - and (source_entity.device_id is not None) - and ( - ( - device := dev_reg.async_get( - device_id=source_entity.device_id, - ) - ) - is not None - ) - ): - device_info = DeviceInfo( - identifiers=device.identifiers, - connections=device.connections, - ) - else: - device_info = None + device_info = async_device_info_to_link_from_entity( + hass, + source_entity_id, + ) if (unit_prefix := config_entry.options.get(CONF_UNIT_PREFIX)) == "none": # Before we had support for optional selectors, "none" was used for selecting nothing diff --git a/tests/components/derivative/test_config_flow.py b/tests/components/derivative/test_config_flow.py index d111df76ece..efdde93173c 100644 --- a/tests/components/derivative/test_config_flow.py +++ b/tests/components/derivative/test_config_flow.py @@ -8,6 +8,7 @@ from homeassistant import config_entries from homeassistant.components.derivative.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from homeassistant.helpers import selector from tests.common import MockConfigEntry @@ -95,6 +96,10 @@ async def test_options(hass: HomeAssistant, platform) -> None: assert await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() + hass.states.async_set("sensor.input", 10, {"unit_of_measurement": "dog"}) + hass.states.async_set("sensor.valid", 10, {"unit_of_measurement": "dog"}) + hass.states.async_set("sensor.invalid", 10, {"unit_of_measurement": "cat"}) + result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] is FlowResultType.FORM assert result["step_id"] == "init" @@ -104,9 +109,17 @@ async def test_options(hass: HomeAssistant, platform) -> None: assert get_suggested(schema, "unit_prefix") == "k" assert get_suggested(schema, "unit_time") == "min" + source = schema["source"] + assert isinstance(source, selector.EntitySelector) + assert source.config["include_entities"] == [ + "sensor.input", + "sensor.valid", + ] + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ + "source": "sensor.valid", "round": 2.0, "time_window": {"seconds": 10.0}, "unit_time": "h", @@ -116,7 +129,7 @@ async def test_options(hass: HomeAssistant, platform) -> None: assert result["data"] == { "name": "My derivative", "round": 2.0, - "source": "sensor.input", + "source": "sensor.valid", "time_window": {"seconds": 10.0}, "unit_time": "h", } @@ -124,7 +137,7 @@ async def test_options(hass: HomeAssistant, platform) -> None: assert config_entry.options == { "name": "My derivative", "round": 2.0, - "source": "sensor.input", + "source": "sensor.valid", "time_window": {"seconds": 10.0}, "unit_time": "h", } @@ -134,11 +147,11 @@ async def test_options(hass: HomeAssistant, platform) -> None: await hass.async_block_till_done() # Check the entity was updated, no new entity was created - assert len(hass.states.async_all()) == 1 + assert len(hass.states.async_all()) == 4 # Check the state of the entity has changed as expected - hass.states.async_set("sensor.input", 10, {"unit_of_measurement": "cat"}) - hass.states.async_set("sensor.input", 11, {"unit_of_measurement": "cat"}) + hass.states.async_set("sensor.valid", 10, {"unit_of_measurement": "cat"}) + hass.states.async_set("sensor.valid", 11, {"unit_of_measurement": "cat"}) await hass.async_block_till_done() state = hass.states.get(f"{platform}.my_derivative") assert state.attributes["unit_of_measurement"] == "cat/h" diff --git a/tests/components/derivative/test_init.py b/tests/components/derivative/test_init.py index 34fe385032b..32b763ee84d 100644 --- a/tests/components/derivative/test_init.py +++ b/tests/components/derivative/test_init.py @@ -4,7 +4,7 @@ import pytest from homeassistant.components.derivative.const import DOMAIN from homeassistant.core import HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from tests.common import MockConfigEntry @@ -60,3 +60,88 @@ async def test_setup_and_remove_config_entry( # Check the state and entity registry entry are removed assert hass.states.get(derivative_entity_id) is None assert entity_registry.async_get(derivative_entity_id) is None + + +async def test_device_cleaning(hass: HomeAssistant) -> None: + """Test for source entity device for Derivative.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + + # Source entity device config entry + source_config_entry = MockConfigEntry() + source_config_entry.add_to_hass(hass) + + # Device entry of the source entity + source_device1_entry = device_registry.async_get_or_create( + config_entry_id=source_config_entry.entry_id, + identifiers={("sensor", "identifier_test1")}, + connections={("mac", "30:31:32:33:34:01")}, + ) + + # Source entity registry + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=source_config_entry, + device_id=source_device1_entry.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + # Configure the configuration entry for Derivative + derivative_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "name": "Derivative", + "round": 1.0, + "source": "sensor.test_source", + "time_window": {"seconds": 0.0}, + "unit_prefix": "k", + "unit_time": "min", + }, + title="Derivative", + ) + derivative_config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(derivative_config_entry.entry_id) + await hass.async_block_till_done() + + # Confirm the link between the source entity device and the derivative sensor + derivative_entity = entity_registry.async_get("sensor.derivative") + assert derivative_entity is not None + assert derivative_entity.device_id == source_entity.device_id + + # Device entry incorrectly linked to Derivative config entry + device_registry.async_get_or_create( + config_entry_id=derivative_config_entry.entry_id, + identifiers={("sensor", "identifier_test2")}, + connections={("mac", "30:31:32:33:34:02")}, + ) + device_registry.async_get_or_create( + config_entry_id=derivative_config_entry.entry_id, + identifiers={("sensor", "identifier_test3")}, + connections={("mac", "30:31:32:33:34:03")}, + ) + await hass.async_block_till_done() + + # Before reloading the config entry, two devices are expected to be linked + devices_before_reload = device_registry.devices.get_devices_for_config_entry_id( + derivative_config_entry.entry_id + ) + assert len(devices_before_reload) == 3 + + # Config entry reload + await hass.config_entries.async_reload(derivative_config_entry.entry_id) + await hass.async_block_till_done() + + # Confirm the link between the source entity device and the derivative sensor after reload + derivative_entity = entity_registry.async_get("sensor.derivative") + assert derivative_entity is not None + assert derivative_entity.device_id == source_entity.device_id + + # After reloading the config entry, only one linked device is expected + devices_after_reload = device_registry.devices.get_devices_for_config_entry_id( + derivative_config_entry.entry_id + ) + assert len(devices_after_reload) == 1