diff --git a/homeassistant/components/integration/config_flow.py b/homeassistant/components/integration/config_flow.py index 20c1b920ec7..dcf67a6b5ef 100644 --- a/homeassistant/components/integration/config_flow.py +++ b/homeassistant/components/integration/config_flow.py @@ -10,11 +10,19 @@ import voluptuous as vol from homeassistant.components.counter import DOMAIN as COUNTER_DOMAIN from homeassistant.components.input_number import DOMAIN as INPUT_NUMBER_DOMAIN from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN -from homeassistant.const import CONF_METHOD, CONF_NAME, UnitOfTime +from homeassistant.const import ( + ATTR_UNIT_OF_MEASUREMENT, + CONF_METHOD, + CONF_NAME, + UnitOfTime, +) +from homeassistant.core import callback from homeassistant.helpers import selector from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, SchemaConfigFlowHandler, SchemaFlowFormStep, + SchemaOptionsFlowHandler, ) from .const import ( @@ -45,25 +53,43 @@ INTEGRATION_METHODS = [ METHOD_LEFT, METHOD_RIGHT, ] +ALLOWED_DOMAINS = [COUNTER_DOMAIN, INPUT_NUMBER_DOMAIN, SENSOR_DOMAIN] -OPTIONS_SCHEMA = vol.Schema( - { - vol.Optional(CONF_ROUND_DIGITS): selector.NumberSelector( - selector.NumberSelectorConfig( - min=0, max=6, mode=selector.NumberSelectorMode.BOX - ), - ), - } -) -CONFIG_SCHEMA = vol.Schema( - { - vol.Required(CONF_NAME): selector.TextSelector(), - vol.Required(CONF_SOURCE_SENSOR): selector.EntitySelector( - selector.EntitySelectorConfig( - domain=[COUNTER_DOMAIN, INPUT_NUMBER_DOMAIN, SENSOR_DOMAIN] - ), - ), +@callback +def entity_selector_compatible( + handler: SchemaOptionsFlowHandler, +) -> selector.EntitySelector: + """Return an entity selector which compatible entities.""" + current = handler.hass.states.get(handler.options[CONF_SOURCE_SENSOR]) + unit_of_measurement = ( + current.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if current else None + ) + + entities = [ + ent.entity_id + for ent in handler.hass.states.async_all(ALLOWED_DOMAINS) + if ent.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == unit_of_measurement + and ent.domain in ALLOWED_DOMAINS + ] + + return selector.EntitySelector( + selector.EntitySelectorConfig(include_entities=entities) + ) + + +async def _get_options_dict(handler: SchemaCommonFlowHandler | None) -> dict: + if handler is None or not isinstance( + handler.parent_handler, SchemaOptionsFlowHandler + ): + entity_selector = selector.EntitySelector( + selector.EntitySelectorConfig(domain=ALLOWED_DOMAINS) + ) + else: + entity_selector = entity_selector_compatible(handler.parent_handler) + + return { + vol.Required(CONF_SOURCE_SENSOR): entity_selector, vol.Required(CONF_METHOD, default=METHOD_TRAPEZOIDAL): selector.SelectSelector( selector.SelectSelectorConfig( options=INTEGRATION_METHODS, translation_key=CONF_METHOD @@ -71,31 +97,46 @@ CONFIG_SCHEMA = vol.Schema( ), vol.Optional(CONF_ROUND_DIGITS): selector.NumberSelector( selector.NumberSelectorConfig( - min=0, - max=6, - mode=selector.NumberSelectorMode.BOX, - unit_of_measurement="decimals", - ), - ), - vol.Optional(CONF_UNIT_PREFIX): selector.SelectSelector( - selector.SelectSelectorConfig(options=UNIT_PREFIXES), - ), - vol.Required(CONF_UNIT_TIME, default=UnitOfTime.HOURS): selector.SelectSelector( - selector.SelectSelectorConfig( - options=TIME_UNITS, - mode=selector.SelectSelectorMode.DROPDOWN, - translation_key=CONF_UNIT_TIME, + min=0, max=6, mode=selector.NumberSelectorMode.BOX ), ), } -) + + +async def _get_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: + return vol.Schema(await _get_options_dict(handler)) + + +async def _get_config_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: + options = await _get_options_dict(handler) + return vol.Schema( + { + vol.Required(CONF_NAME): selector.TextSelector(), + vol.Optional(CONF_UNIT_PREFIX): selector.SelectSelector( + selector.SelectSelectorConfig( + options=UNIT_PREFIXES, mode=selector.SelectSelectorMode.DROPDOWN + ) + ), + vol.Required( + CONF_UNIT_TIME, default=UnitOfTime.HOURS + ): selector.SelectSelector( + selector.SelectSelectorConfig( + options=TIME_UNITS, + mode=selector.SelectSelectorMode.DROPDOWN, + translation_key=CONF_UNIT_TIME, + ), + ), + **options, + } + ) + CONFIG_FLOW = { - "user": SchemaFlowFormStep(CONFIG_SCHEMA), + "user": SchemaFlowFormStep(_get_config_schema), } OPTIONS_FLOW = { - "init": SchemaFlowFormStep(OPTIONS_SCHEMA), + "init": SchemaFlowFormStep(_get_options_schema), } diff --git a/homeassistant/components/integration/strings.json b/homeassistant/components/integration/strings.json index 74c2b3ee440..0f5231399b7 100644 --- a/homeassistant/components/integration/strings.json +++ b/homeassistant/components/integration/strings.json @@ -25,10 +25,16 @@ "step": { "init": { "data": { - "round": "[%key:component::integration::config::step::user::data::round%]" + "method": "[%key:component::integration::config::step::user::data::method%]", + "round": "[%key:component::integration::config::step::user::data::round%]", + "source": "[%key:component::integration::config::step::user::data::source%]", + "unit_prefix": "[%key:component::integration::config::step::user::data::unit_prefix%]", + "unit_time": "[%key:component::integration::config::step::user::data::unit_time%]" }, "data_description": { - "round": "[%key:component::integration::config::step::user::data_description::round%]" + "round": "[%key:component::integration::config::step::user::data_description::round%]", + "unit_prefix": "[%key:component::integration::config::step::user::data_description::unit_prefix%]", + "unit_time": "[%key:component::integration::config::step::user::data_description::unit_time%]" } } } diff --git a/tests/components/integration/test_config_flow.py b/tests/components/integration/test_config_flow.py index 179984f20f2..ede2146185d 100644 --- a/tests/components/integration/test_config_flow.py +++ b/tests/components/integration/test_config_flow.py @@ -8,6 +8,7 @@ from homeassistant import config_entries from homeassistant.components.integration.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from homeassistant.helpers import selector from tests.common import MockConfigEntry @@ -95,21 +96,34 @@ async def test_options(hass: HomeAssistant, platform) -> None: assert await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() + hass.states.async_set("sensor.input", 10, {"unit_of_measurement": "dog"}) + hass.states.async_set("sensor.valid", 10, {"unit_of_measurement": "dog"}) + hass.states.async_set("sensor.invalid", 10, {"unit_of_measurement": "cat"}) + result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] is FlowResultType.FORM assert result["step_id"] == "init" schema = result["data_schema"].schema assert get_suggested(schema, "round") == 1.0 + source = schema["source"] + assert isinstance(source, selector.EntitySelector) + assert source.config["include_entities"] == [ + "sensor.input", + "sensor.valid", + ] + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ + "method": "right", "round": 2.0, + "source": "sensor.input", }, ) assert result["type"] is FlowResultType.CREATE_ENTRY assert result["data"] == { - "method": "left", + "method": "right", "name": "My integration", "round": 2.0, "source": "sensor.input", @@ -118,7 +132,7 @@ async def test_options(hass: HomeAssistant, platform) -> None: } assert config_entry.data == {} assert config_entry.options == { - "method": "left", + "method": "right", "name": "My integration", "round": 2.0, "source": "sensor.input", @@ -131,7 +145,7 @@ async def test_options(hass: HomeAssistant, platform) -> None: await hass.async_block_till_done() # Check the entity was updated, no new entity was created - assert len(hass.states.async_all()) == 1 + assert len(hass.states.async_all()) == 4 # Check the state of the entity has changed as expected hass.states.async_set("sensor.input", 10, {"unit_of_measurement": "dog"})