diff --git a/homeassistant/components/airvisual/__init__.py b/homeassistant/components/airvisual/__init__.py index e234c2b1c67..4352b15b8a5 100644 --- a/homeassistant/components/airvisual/__init__.py +++ b/homeassistant/components/airvisual/__init__.py @@ -22,6 +22,7 @@ from homeassistant.helpers.event import async_track_time_interval from .const import ( CONF_CITY, CONF_COUNTRY, + CONF_GEOGRAPHIES, DATA_CLIENT, DEFAULT_SCAN_INTERVAL, DOMAIN, @@ -34,8 +35,6 @@ DATA_LISTENER = "listener" DEFAULT_OPTIONS = {CONF_SHOW_ON_MAP: True} -CONF_GEOGRAPHIES = "geographies" - GEOGRAPHY_COORDINATES_SCHEMA = vol.Schema( { vol.Required(CONF_LATITUDE): cv.latitude, @@ -158,8 +157,7 @@ async def async_migrate_entry(hass, config_entry): # Update the config entry to only include the first geography (there is always # guaranteed to be at least one): - data = {**config_entry.data} - geographies = data.pop(CONF_GEOGRAPHIES) + geographies = list(config_entry.data[CONF_GEOGRAPHIES]) first_geography = geographies.pop(0) first_id = async_get_geography_id(first_geography) diff --git a/homeassistant/components/airvisual/config_flow.py b/homeassistant/components/airvisual/config_flow.py index 047f585a4ff..0c9c0e65ff1 100644 --- a/homeassistant/components/airvisual/config_flow.py +++ b/homeassistant/components/airvisual/config_flow.py @@ -16,7 +16,7 @@ from homeassistant.core import callback from homeassistant.helpers import aiohttp_client, config_validation as cv from . import async_get_geography_id -from .const import DOMAIN # pylint: disable=unused-import +from .const import CONF_GEOGRAPHIES, DOMAIN # pylint: disable=unused-import class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): @@ -69,6 +69,18 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): geo_id = async_get_geography_id(user_input) await self._async_set_unique_id(geo_id) + self._abort_if_unique_id_configured() + + # Find older config entries without unique ID + for entry in self._async_current_entries(): + if entry.version != 1: + continue + + if any( + geo_id == async_get_geography_id(geography) + for geography in entry.data[CONF_GEOGRAPHIES] + ): + return self.async_abort(reason="already_configured") websession = aiohttp_client.async_get_clientsession(self.hass) client = Client(websession, api_key=user_input[CONF_API_KEY]) @@ -90,9 +102,10 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ) checked_keys.add(user_input[CONF_API_KEY]) - return self.async_create_entry( - title=f"Cloud API ({geo_id})", data=user_input - ) + + return self.async_create_entry( + title=f"Cloud API ({geo_id})", data=user_input + ) class AirVisualOptionsFlowHandler(config_entries.OptionsFlow): diff --git a/homeassistant/components/airvisual/const.py b/homeassistant/components/airvisual/const.py index 3bfc224a735..ab54e191116 100644 --- a/homeassistant/components/airvisual/const.py +++ b/homeassistant/components/airvisual/const.py @@ -5,6 +5,7 @@ DOMAIN = "airvisual" CONF_CITY = "city" CONF_COUNTRY = "country" +CONF_GEOGRAPHIES = "geographies" DATA_CLIENT = "client" diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 20763dd39a5..51f083b7eeb 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -1,5 +1,6 @@ """Classes to help gather user submissions.""" import abc +import asyncio import logging from typing import Any, Dict, List, Optional, cast import uuid @@ -53,8 +54,18 @@ class FlowManager(abc.ABC): def __init__(self, hass: HomeAssistant,) -> None: """Initialize the flow manager.""" self.hass = hass + self._initializing: Dict[str, List[asyncio.Future]] = {} self._progress: Dict[str, Any] = {} + async def async_wait_init_flow_finish(self, handler: str) -> None: + """Wait till all flows in progress are initialized.""" + current = self._initializing.get(handler) + + if not current: + return + + await asyncio.wait(current) + @abc.abstractmethod async def async_create_flow( self, @@ -94,8 +105,13 @@ class FlowManager(abc.ABC): """Start a configuration flow.""" if context is None: context = {} + + init_done: asyncio.Future = asyncio.Future() + self._initializing.setdefault(handler, []).append(init_done) + flow = await self.async_create_flow(handler, context=context, data=data) if not flow: + self._initializing[handler].remove(init_done) raise UnknownFlow("Flow was not created") flow.hass = self.hass flow.handler = handler @@ -103,7 +119,12 @@ class FlowManager(abc.ABC): flow.context = context self._progress[flow.flow_id] = flow - result = await self._async_handle_step(flow, flow.init_step, data) + try: + result = await self._async_handle_step( + flow, flow.init_step, data, init_done + ) + finally: + self._initializing[handler].remove(init_done) if result["type"] != RESULT_TYPE_ABORT: await self.async_post_init(flow, result) @@ -154,13 +175,19 @@ class FlowManager(abc.ABC): raise UnknownFlow async def _async_handle_step( - self, flow: Any, step_id: str, user_input: Optional[Dict] + self, + flow: Any, + step_id: str, + user_input: Optional[Dict], + step_done: Optional[asyncio.Future] = None, ) -> Dict: """Handle a step of a flow.""" method = f"async_step_{step_id}" if not hasattr(flow, method): self._progress.pop(flow.flow_id) + if step_done: + step_done.set_result(None) raise UnknownStep( f"Handler {flow.__class__.__name__} doesn't support step {step_id}" ) @@ -172,6 +199,13 @@ class FlowManager(abc.ABC): flow.flow_id, flow.handler, err.reason, err.description_placeholders ) + # Mark the step as done. + # We do this before calling async_finish_flow because config entries will hit a + # circular dependency where async_finish_flow sets up new entry, which needs the + # integration to be set up, which is waiting for init to be done. + if step_done: + step_done.set_result(None) + if result["type"] not in ( RESULT_TYPE_FORM, RESULT_TYPE_EXTERNAL_STEP, diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 40d767728d3..82b7e1be039 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -197,9 +197,12 @@ async def _async_setup_component( ) return False - if hass.config_entries: - for entry in hass.config_entries.async_entries(domain): - await entry.async_setup(hass, integration=integration) + # Flush out async_setup calling create_task. Fragile but covered by test. + await asyncio.sleep(0) + await hass.config_entries.flow.async_wait_init_flow_finish(domain) + + for entry in hass.config_entries.async_entries(domain): + await entry.async_setup(hass, integration=integration) hass.config.components.add(domain) diff --git a/tests/components/axis/test_init.py b/tests/components/axis/test_init.py index 83e1337b079..d11f8c91fc3 100644 --- a/tests/components/axis/test_init.py +++ b/tests/components/axis/test_init.py @@ -9,19 +9,6 @@ from .test_device import MAC, setup_axis_integration from tests.common import MockConfigEntry, mock_coro -async def test_setup_device_already_configured(hass): - """Test already configured device does not configure a second.""" - with patch.object(hass, "config_entries") as mock_config_entries: - - assert await async_setup_component( - hass, - axis.DOMAIN, - {axis.DOMAIN: {"device_name": {axis.CONF_HOST: "1.2.3.4"}}}, - ) - - assert not mock_config_entries.flow.mock_calls - - async def test_setup_no_config(hass): """Test setup without configuration.""" assert await async_setup_component(hass, axis.DOMAIN, {}) diff --git a/tests/components/konnected/test_init.py b/tests/components/konnected/test_init.py index 28071291266..1bf239852f8 100644 --- a/tests/components/konnected/test_init.py +++ b/tests/components/konnected/test_init.py @@ -230,7 +230,7 @@ async def test_setup_with_no_config(hass): assert konnected.YAML_CONFIGS not in hass.data[konnected.DOMAIN] -async def test_setup_defined_hosts_known_auth(hass): +async def test_setup_defined_hosts_known_auth(hass, mock_panel): """Test we don't initiate a config entry if configured panel is known.""" MockConfigEntry( domain="konnected", diff --git a/tests/components/luftdaten/test_init.py b/tests/components/luftdaten/test_init.py index fe7bca654c3..ebe5f73669e 100644 --- a/tests/components/luftdaten/test_init.py +++ b/tests/components/luftdaten/test_init.py @@ -15,7 +15,9 @@ async def test_config_with_sensor_passed_to_config_entry(hass): CONF_SCAN_INTERVAL: 600, } - with patch.object(hass, "config_entries") as mock_config_entries, patch.object( + with patch.object( + hass.config_entries.flow, "async_init" + ) as mock_config_entries, patch.object( luftdaten, "configured_sensors", return_value=[] ): assert await async_setup_component(hass, DOMAIN, conf) is True @@ -27,7 +29,9 @@ async def test_config_already_registered_not_passed_to_config_entry(hass): """Test that an already registered sensor does not initiate an import.""" conf = {CONF_SENSOR_ID: "12345abcde"} - with patch.object(hass, "config_entries") as mock_config_entries, patch.object( + with patch.object( + hass.config_entries.flow, "async_init" + ) as mock_config_entries, patch.object( luftdaten, "configured_sensors", return_value=["12345abcde"] ): assert await async_setup_component(hass, DOMAIN, conf) is True diff --git a/tests/components/zeroconf/test_init.py b/tests/components/zeroconf/test_init.py index 4e086978be1..6a3dc1f5941 100644 --- a/tests/components/zeroconf/test_init.py +++ b/tests/components/zeroconf/test_init.py @@ -55,7 +55,9 @@ def get_homekit_info_mock(model): async def test_setup(hass, mock_zeroconf): """Test configured options for a device are loaded via config entry.""" - with patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( + with patch.object( + hass.config_entries.flow, "async_init" + ) as mock_config_flow, patch.object( zeroconf, "ServiceBrowser", side_effect=service_update_mock ) as mock_service_browser: mock_zeroconf.get_service_info.side_effect = get_service_info_mock @@ -72,7 +74,9 @@ async def test_homekit_match_partial_space(hass, mock_zeroconf): """Test configured options for a device are loaded via config entry.""" with patch.dict( zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True - ), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( + ), patch.object( + hass.config_entries.flow, "async_init" + ) as mock_config_flow, patch.object( zeroconf, "ServiceBrowser", side_effect=service_update_mock ) as mock_service_browser: mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("LIFX bulb") @@ -87,7 +91,9 @@ async def test_homekit_match_partial_dash(hass, mock_zeroconf): """Test configured options for a device are loaded via config entry.""" with patch.dict( zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True - ), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( + ), patch.object( + hass.config_entries.flow, "async_init" + ) as mock_config_flow, patch.object( zeroconf, "ServiceBrowser", side_effect=service_update_mock ) as mock_service_browser: mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock( @@ -104,7 +110,9 @@ async def test_homekit_match_full(hass, mock_zeroconf): """Test configured options for a device are loaded via config entry.""" with patch.dict( zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True - ), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( + ), patch.object( + hass.config_entries.flow, "async_init" + ) as mock_config_flow, patch.object( zeroconf, "ServiceBrowser", side_effect=service_update_mock ) as mock_service_browser: mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("BSB002") diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 86eb3919f00..28746bbfbe0 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3,6 +3,7 @@ import asyncio from datetime import timedelta from unittest.mock import MagicMock, patch +from asynctest import CoroutineMock import pytest from homeassistant import config_entries, data_entry_flow, loader @@ -1463,3 +1464,97 @@ async def test_partial_flows_hidden(hass, manager): await hass.async_block_till_done() state = hass.states.get("persistent_notification.config_entry_discovery") assert state is not None + + +async def test_async_setup_init_entry(hass): + """Test a config entry being initialized during integration setup.""" + + async def mock_async_setup(hass, config): + """Mock setup.""" + hass.async_create_task( + hass.config_entries.flow.async_init( + "comp", context={"source": config_entries.SOURCE_IMPORT}, data={}, + ) + ) + return True + + async_setup_entry = CoroutineMock(return_value=True) + mock_integration( + hass, + MockModule( + "comp", async_setup=mock_async_setup, async_setup_entry=async_setup_entry + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + await async_setup_component(hass, "persistent_notification", {}) + + class TestFlow(config_entries.ConfigFlow): + """Test flow.""" + + VERSION = 1 + + async def async_step_import(self, user_input): + """Test import step creating entry.""" + return self.async_create_entry(title="title", data={}) + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + assert await async_setup_component(hass, "comp", {}) + + await hass.async_block_till_done() + + assert len(async_setup_entry.mock_calls) == 1 + + entries = hass.config_entries.async_entries("comp") + assert len(entries) == 1 + assert entries[0].state == config_entries.ENTRY_STATE_LOADED + + +async def test_async_setup_update_entry(hass): + """Test a config entry being updated during integration setup.""" + entry = MockConfigEntry(domain="comp", data={"value": "initial"}) + entry.add_to_hass(hass) + + async def mock_async_setup(hass, config): + """Mock setup.""" + hass.async_create_task( + hass.config_entries.flow.async_init( + "comp", context={"source": config_entries.SOURCE_IMPORT}, data={}, + ) + ) + return True + + async def mock_async_setup_entry(hass, entry): + """Mock setting up an entry.""" + assert entry.data["value"] == "updated" + return True + + mock_integration( + hass, + MockModule( + "comp", + async_setup=mock_async_setup, + async_setup_entry=mock_async_setup_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + await async_setup_component(hass, "persistent_notification", {}) + + class TestFlow(config_entries.ConfigFlow): + """Test flow.""" + + VERSION = 1 + + async def async_step_import(self, user_input): + """Test import step updating existing entry.""" + self.hass.config_entries.async_update_entry( + entry, data={"value": "updated"} + ) + return self.async_abort(reason="yo") + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + assert await async_setup_component(hass, "comp", {}) + + entries = hass.config_entries.async_entries("comp") + assert len(entries) == 1 + assert entries[0].state == config_entries.ENTRY_STATE_LOADED + assert entries[0].data == {"value": "updated"}