From f17a829bd8a0349b316b565f3fcea3fced64f32a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 18 Jan 2023 10:44:18 +0100 Subject: [PATCH] Only wait for import flows to initialize at setup (#86106) * Only wait for import flows to initialize at setup * Update hassio tests * Update hassio tests * Apply suggestions from code review Co-authored-by: Martin Hjelmare Co-authored-by: Martin Hjelmare --- .../components/airvisual/__init__.py | 8 ++++-- .../components/airvisual/config_flow.py | 7 ++++++ homeassistant/config_entries.py | 25 ++++++++++--------- homeassistant/setup.py | 2 +- tests/components/airvisual/conftest.py | 8 ++++++ tests/components/airvisual/test_init.py | 17 ++++++++++++- tests/components/hassio/test_init.py | 23 +++++++++++------ tests/components/hassio/test_update.py | 4 +-- tests/test_config_entries.py | 2 +- 9 files changed, 69 insertions(+), 27 deletions(-) diff --git a/homeassistant/components/airvisual/__init__.py b/homeassistant/components/airvisual/__init__.py index 32c2d71292f..793b7879270 100644 --- a/homeassistant/components/airvisual/__init__.py +++ b/homeassistant/components/airvisual/__init__.py @@ -266,8 +266,12 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.async_create_task( hass.config_entries.flow.async_init( DOMAIN, - context={"source": source}, - data={CONF_API_KEY: entry.data[CONF_API_KEY], **geography}, + context={"source": SOURCE_IMPORT}, + data={ + "import_source": source, + CONF_API_KEY: entry.data[CONF_API_KEY], + **geography, + }, ) ) diff --git a/homeassistant/components/airvisual/config_flow.py b/homeassistant/components/airvisual/config_flow.py index 5d8ab5210d5..27e79f2d40b 100644 --- a/homeassistant/components/airvisual/config_flow.py +++ b/homeassistant/components/airvisual/config_flow.py @@ -171,6 +171,13 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Define the config flow to handle options.""" return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW) + async def async_step_import(self, import_data: dict[str, str]) -> FlowResult: + """Handle import of config entry version 1 data.""" + import_source = import_data.pop("import_source") + if import_source == "geography_by_coords": + return await self.async_step_geography_by_coords(import_data) + return await self.async_step_geography_by_name(import_data) + async def async_step_geography_by_coords( self, user_input: dict[str, str] | None = None ) -> FlowResult: diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 2d4774024be..c908f1916e4 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -761,12 +761,12 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): super().__init__(hass) self.config_entries = config_entries self._hass_config = hass_config - self._initializing: dict[str, dict[str, asyncio.Future]] = {} + self._pending_import_flows: dict[str, dict[str, asyncio.Future[None]]] = {} self._initialize_tasks: dict[str, list[asyncio.Task]] = {} - async def async_wait_init_flow_finish(self, handler: str) -> None: - """Wait till all flows in progress are initialized.""" - if not (current := self._initializing.get(handler)): + async def async_wait_import_flow_initialized(self, handler: str) -> None: + """Wait till all import flows in progress are initialized.""" + if not (current := self._pending_import_flows.get(handler)): return await asyncio.wait(current.values()) @@ -783,12 +783,13 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None ) -> FlowResult: """Start a configuration flow.""" - if context is None: - context = {} + if not context or "source" not in context: + raise KeyError("Context not set or doesn't have a source set") flow_id = uuid_util.random_uuid_hex() - init_done: asyncio.Future = asyncio.Future() - self._initializing.setdefault(handler, {})[flow_id] = init_done + if context["source"] == SOURCE_IMPORT: + init_done: asyncio.Future[None] = asyncio.Future() + self._pending_import_flows.setdefault(handler, {})[flow_id] = init_done task = asyncio.create_task(self._async_init(flow_id, handler, context, data)) self._initialize_tasks.setdefault(handler, []).append(task) @@ -797,7 +798,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): flow, result = await task finally: self._initialize_tasks[handler].remove(task) - self._initializing[handler].pop(flow_id) + self._pending_import_flows.get(handler, {}).pop(flow_id, None) if result["type"] != data_entry_flow.FlowResultType.ABORT: await self.async_post_init(flow, result) @@ -824,8 +825,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): try: result = await self._async_handle_step(flow, flow.init_step, data) finally: - init_done = self._initializing[handler][flow_id] - if not init_done.done(): + init_done = self._pending_import_flows.get(handler, {}).get(flow_id) + if init_done and not init_done.done(): init_done.set_result(None) return flow, result @@ -845,7 +846,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): # We do this to avoid a circular dependency where async_finish_flow sets up a # new entry, which needs the integration to be set up, which is waiting for # init to be done. - init_done = self._initializing[flow.handler].get(flow.flow_id) + init_done = self._pending_import_flows.get(flow.handler, {}).get(flow.flow_id) if init_done and not init_done.done(): init_done.set_result(None) diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 94aa3ab1b03..9740d338eff 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -286,7 +286,7 @@ async def _async_setup_component( # Flush out async_setup calling create_task. Fragile but covered by test. await asyncio.sleep(0) - await hass.config_entries.flow.async_wait_init_flow_finish(domain) + await hass.config_entries.flow.async_wait_import_flow_initialized(domain) # Add to components before the entry.async_setup # call to avoid a deadlock when forwarding platforms diff --git a/tests/components/airvisual/conftest.py b/tests/components/airvisual/conftest.py index eb2ba2d82c9..c85d6e90c4b 100644 --- a/tests/components/airvisual/conftest.py +++ b/tests/components/airvisual/conftest.py @@ -25,6 +25,8 @@ from tests.common import MockConfigEntry, load_fixture TEST_API_KEY = "abcde12345" TEST_LATITUDE = 51.528308 TEST_LONGITUDE = -0.3817765 +TEST_LATITUDE2 = 37.514626 +TEST_LONGITUDE2 = 127.057414 COORDS_CONFIG = { CONF_API_KEY: TEST_API_KEY, @@ -32,6 +34,12 @@ COORDS_CONFIG = { CONF_LONGITUDE: TEST_LONGITUDE, } +COORDS_CONFIG2 = { + CONF_API_KEY: TEST_API_KEY, + CONF_LATITUDE: TEST_LATITUDE2, + CONF_LONGITUDE: TEST_LONGITUDE2, +} + TEST_CITY = "Beijing" TEST_STATE = "Beijing" TEST_COUNTRY = "China" diff --git a/tests/components/airvisual/test_init.py b/tests/components/airvisual/test_init.py index a02543dc7f1..b9459f5608b 100644 --- a/tests/components/airvisual/test_init.py +++ b/tests/components/airvisual/test_init.py @@ -24,12 +24,15 @@ from homeassistant.helpers import device_registry as dr, issue_registry as ir from .conftest import ( COORDS_CONFIG, + COORDS_CONFIG2, NAME_CONFIG, TEST_API_KEY, TEST_CITY, TEST_COUNTRY, TEST_LATITUDE, + TEST_LATITUDE2, TEST_LONGITUDE, + TEST_LONGITUDE2, TEST_STATE, ) @@ -53,6 +56,10 @@ async def test_migration_1_2(hass, mock_pyairvisual): CONF_STATE: TEST_STATE, CONF_COUNTRY: TEST_COUNTRY, }, + { + CONF_LATITUDE: TEST_LATITUDE2, + CONF_LONGITUDE: TEST_LONGITUDE2, + }, ], }, version=1, @@ -63,7 +70,7 @@ async def test_migration_1_2(hass, mock_pyairvisual): await hass.async_block_till_done() config_entries = hass.config_entries.async_entries(DOMAIN) - assert len(config_entries) == 2 + assert len(config_entries) == 3 # Ensure that after migration, each configuration has its own config entry: identifier1 = f"{TEST_LATITUDE}, {TEST_LONGITUDE}" @@ -82,6 +89,14 @@ async def test_migration_1_2(hass, mock_pyairvisual): CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_NAME, } + identifier3 = f"{TEST_LATITUDE2}, {TEST_LONGITUDE2}" + assert config_entries[2].unique_id == identifier3 + assert config_entries[2].title == f"Cloud API ({identifier3})" + assert config_entries[2].data == { + **COORDS_CONFIG2, + CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_COORDS, + } + async def test_migration_2_3(hass, mock_pyairvisual): """Test migrating from version 2 to 3.""" diff --git a/tests/components/hassio/test_init.py b/tests/components/hassio/test_init.py index 371398e32c9..58e4fc1552d 100644 --- a/tests/components/hassio/test_init.py +++ b/tests/components/hassio/test_init.py @@ -202,8 +202,9 @@ async def test_setup_api_ping(hass, aioclient_mock): """Test setup with API ping.""" with patch.dict(os.environ, MOCK_ENVIRON): result = await async_setup_component(hass, "hassio", {}) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert hass.components.hassio.get_core_info()["version_latest"] == "1.0.0" assert hass.components.hassio.is_hassio() @@ -241,8 +242,9 @@ async def test_setup_api_push_api_data(hass, aioclient_mock): result = await async_setup_component( hass, "hassio", {"http": {"server_port": 9999}, "hassio": {}} ) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert not aioclient_mock.mock_calls[1][2]["ssl"] assert aioclient_mock.mock_calls[1][2]["port"] == 9999 @@ -257,8 +259,9 @@ async def test_setup_api_push_api_data_server_host(hass, aioclient_mock): "hassio", {"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}}, ) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert not aioclient_mock.mock_calls[1][2]["ssl"] assert aioclient_mock.mock_calls[1][2]["port"] == 9999 @@ -269,8 +272,9 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, hass_storag """Test setup with API push default data.""" with patch.dict(os.environ, MOCK_ENVIRON): result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}}) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert not aioclient_mock.mock_calls[1][2]["ssl"] assert aioclient_mock.mock_calls[1][2]["port"] == 8123 @@ -336,8 +340,9 @@ async def test_setup_api_existing_hassio_user(hass, aioclient_mock, hass_storage hass_storage[STORAGE_KEY] = {"version": 1, "data": {"hassio_user": user.id}} with patch.dict(os.environ, MOCK_ENVIRON): result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}}) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert not aioclient_mock.mock_calls[1][2]["ssl"] assert aioclient_mock.mock_calls[1][2]["port"] == 8123 @@ -350,8 +355,9 @@ async def test_setup_core_push_timezone(hass, aioclient_mock): with patch.dict(os.environ, MOCK_ENVIRON): result = await async_setup_component(hass, "hassio", {"hassio": {}}) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert aioclient_mock.mock_calls[2][2]["timezone"] == "testzone" @@ -367,8 +373,9 @@ async def test_setup_hassio_no_additional_data(hass, aioclient_mock): os.environ, {"SUPERVISOR_TOKEN": "123456"} ): result = await async_setup_component(hass, "hassio", {"hassio": {}}) - assert result + await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert aioclient_mock.mock_calls[-1][3]["Authorization"] == "Bearer 123456" @@ -768,9 +775,9 @@ async def test_setup_hardware_integration(hass, aioclient_mock, integration): return_value=True, ) as mock_setup_entry: result = await async_setup_component(hass, "hassio", {"hassio": {}}) - assert result await hass.async_block_till_done() + assert result assert aioclient_mock.call_count == 16 assert len(mock_setup_entry.mock_calls) == 1 diff --git a/tests/components/hassio/test_update.py b/tests/components/hassio/test_update.py index 02d6b1dbf6b..8391ea66b5d 100644 --- a/tests/components/hassio/test_update.py +++ b/tests/components/hassio/test_update.py @@ -542,8 +542,8 @@ async def test_setting_up_core_update_when_addon_fails(hass, caplog): "hassio", {"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}}, ) - assert result - await hass.async_block_till_done() + await hass.async_block_till_done() + assert result # Verify that the core update entity does exist state = hass.states.get("update.home_assistant_core_update") diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 198e79ec189..2943c2b9c57 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1428,7 +1428,7 @@ async def test_init_custom_integration(hass): "homeassistant.loader.async_get_integration", return_value=integration, ): - await hass.config_entries.flow.async_init("bla") + await hass.config_entries.flow.async_init("bla", context={"source": "user"}) async def test_support_entry_unload(hass):