Only wait for import flows to initialize at setup (#86106)

* Only wait for import flows to initialize at setup

* Update hassio tests

* Update hassio tests

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2023-01-18 10:44:18 +01:00 committed by GitHub
parent 767b43bb0e
commit f17a829bd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 69 additions and 27 deletions

View File

@ -266,8 +266,12 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": source}, context={"source": SOURCE_IMPORT},
data={CONF_API_KEY: entry.data[CONF_API_KEY], **geography}, data={
"import_source": source,
CONF_API_KEY: entry.data[CONF_API_KEY],
**geography,
},
) )
) )

View File

@ -171,6 +171,13 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Define the config flow to handle options.""" """Define the config flow to handle options."""
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW) return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
async def async_step_import(self, import_data: dict[str, str]) -> FlowResult:
"""Handle import of config entry version 1 data."""
import_source = import_data.pop("import_source")
if import_source == "geography_by_coords":
return await self.async_step_geography_by_coords(import_data)
return await self.async_step_geography_by_name(import_data)
async def async_step_geography_by_coords( async def async_step_geography_by_coords(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> FlowResult:

View File

@ -761,12 +761,12 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
super().__init__(hass) super().__init__(hass)
self.config_entries = config_entries self.config_entries = config_entries
self._hass_config = hass_config self._hass_config = hass_config
self._initializing: dict[str, dict[str, asyncio.Future]] = {} self._pending_import_flows: dict[str, dict[str, asyncio.Future[None]]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {} self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None: async def async_wait_import_flow_initialized(self, handler: str) -> None:
"""Wait till all flows in progress are initialized.""" """Wait till all import flows in progress are initialized."""
if not (current := self._initializing.get(handler)): if not (current := self._pending_import_flows.get(handler)):
return return
await asyncio.wait(current.values()) await asyncio.wait(current.values())
@ -783,12 +783,13 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult: ) -> FlowResult:
"""Start a configuration flow.""" """Start a configuration flow."""
if context is None: if not context or "source" not in context:
context = {} raise KeyError("Context not set or doesn't have a source set")
flow_id = uuid_util.random_uuid_hex() flow_id = uuid_util.random_uuid_hex()
init_done: asyncio.Future = asyncio.Future() if context["source"] == SOURCE_IMPORT:
self._initializing.setdefault(handler, {})[flow_id] = init_done init_done: asyncio.Future[None] = asyncio.Future()
self._pending_import_flows.setdefault(handler, {})[flow_id] = init_done
task = asyncio.create_task(self._async_init(flow_id, handler, context, data)) task = asyncio.create_task(self._async_init(flow_id, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task) self._initialize_tasks.setdefault(handler, []).append(task)
@ -797,7 +798,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
flow, result = await task flow, result = await task
finally: finally:
self._initialize_tasks[handler].remove(task) self._initialize_tasks[handler].remove(task)
self._initializing[handler].pop(flow_id) self._pending_import_flows.get(handler, {}).pop(flow_id, None)
if result["type"] != data_entry_flow.FlowResultType.ABORT: if result["type"] != data_entry_flow.FlowResultType.ABORT:
await self.async_post_init(flow, result) await self.async_post_init(flow, result)
@ -824,8 +825,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
try: try:
result = await self._async_handle_step(flow, flow.init_step, data) result = await self._async_handle_step(flow, flow.init_step, data)
finally: finally:
init_done = self._initializing[handler][flow_id] init_done = self._pending_import_flows.get(handler, {}).get(flow_id)
if not init_done.done(): if init_done and not init_done.done():
init_done.set_result(None) init_done.set_result(None)
return flow, result return flow, result
@ -845,7 +846,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
# We do this to avoid a circular dependency where async_finish_flow sets up a # We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for # new entry, which needs the integration to be set up, which is waiting for
# init to be done. # init to be done.
init_done = self._initializing[flow.handler].get(flow.flow_id) init_done = self._pending_import_flows.get(flow.handler, {}).get(flow.flow_id)
if init_done and not init_done.done(): if init_done and not init_done.done():
init_done.set_result(None) init_done.set_result(None)

View File

@ -286,7 +286,7 @@ async def _async_setup_component(
# Flush out async_setup calling create_task. Fragile but covered by test. # Flush out async_setup calling create_task. Fragile but covered by test.
await asyncio.sleep(0) await asyncio.sleep(0)
await hass.config_entries.flow.async_wait_init_flow_finish(domain) await hass.config_entries.flow.async_wait_import_flow_initialized(domain)
# Add to components before the entry.async_setup # Add to components before the entry.async_setup
# call to avoid a deadlock when forwarding platforms # call to avoid a deadlock when forwarding platforms

View File

@ -25,6 +25,8 @@ from tests.common import MockConfigEntry, load_fixture
TEST_API_KEY = "abcde12345" TEST_API_KEY = "abcde12345"
TEST_LATITUDE = 51.528308 TEST_LATITUDE = 51.528308
TEST_LONGITUDE = -0.3817765 TEST_LONGITUDE = -0.3817765
TEST_LATITUDE2 = 37.514626
TEST_LONGITUDE2 = 127.057414
COORDS_CONFIG = { COORDS_CONFIG = {
CONF_API_KEY: TEST_API_KEY, CONF_API_KEY: TEST_API_KEY,
@ -32,6 +34,12 @@ COORDS_CONFIG = {
CONF_LONGITUDE: TEST_LONGITUDE, CONF_LONGITUDE: TEST_LONGITUDE,
} }
COORDS_CONFIG2 = {
CONF_API_KEY: TEST_API_KEY,
CONF_LATITUDE: TEST_LATITUDE2,
CONF_LONGITUDE: TEST_LONGITUDE2,
}
TEST_CITY = "Beijing" TEST_CITY = "Beijing"
TEST_STATE = "Beijing" TEST_STATE = "Beijing"
TEST_COUNTRY = "China" TEST_COUNTRY = "China"

View File

@ -24,12 +24,15 @@ from homeassistant.helpers import device_registry as dr, issue_registry as ir
from .conftest import ( from .conftest import (
COORDS_CONFIG, COORDS_CONFIG,
COORDS_CONFIG2,
NAME_CONFIG, NAME_CONFIG,
TEST_API_KEY, TEST_API_KEY,
TEST_CITY, TEST_CITY,
TEST_COUNTRY, TEST_COUNTRY,
TEST_LATITUDE, TEST_LATITUDE,
TEST_LATITUDE2,
TEST_LONGITUDE, TEST_LONGITUDE,
TEST_LONGITUDE2,
TEST_STATE, TEST_STATE,
) )
@ -53,6 +56,10 @@ async def test_migration_1_2(hass, mock_pyairvisual):
CONF_STATE: TEST_STATE, CONF_STATE: TEST_STATE,
CONF_COUNTRY: TEST_COUNTRY, CONF_COUNTRY: TEST_COUNTRY,
}, },
{
CONF_LATITUDE: TEST_LATITUDE2,
CONF_LONGITUDE: TEST_LONGITUDE2,
},
], ],
}, },
version=1, version=1,
@ -63,7 +70,7 @@ async def test_migration_1_2(hass, mock_pyairvisual):
await hass.async_block_till_done() await hass.async_block_till_done()
config_entries = hass.config_entries.async_entries(DOMAIN) config_entries = hass.config_entries.async_entries(DOMAIN)
assert len(config_entries) == 2 assert len(config_entries) == 3
# Ensure that after migration, each configuration has its own config entry: # Ensure that after migration, each configuration has its own config entry:
identifier1 = f"{TEST_LATITUDE}, {TEST_LONGITUDE}" identifier1 = f"{TEST_LATITUDE}, {TEST_LONGITUDE}"
@ -82,6 +89,14 @@ async def test_migration_1_2(hass, mock_pyairvisual):
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_NAME, CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_NAME,
} }
identifier3 = f"{TEST_LATITUDE2}, {TEST_LONGITUDE2}"
assert config_entries[2].unique_id == identifier3
assert config_entries[2].title == f"Cloud API ({identifier3})"
assert config_entries[2].data == {
**COORDS_CONFIG2,
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_COORDS,
}
async def test_migration_2_3(hass, mock_pyairvisual): async def test_migration_2_3(hass, mock_pyairvisual):
"""Test migrating from version 2 to 3.""" """Test migrating from version 2 to 3."""

View File

@ -202,8 +202,9 @@ async def test_setup_api_ping(hass, aioclient_mock):
"""Test setup with API ping.""" """Test setup with API ping."""
with patch.dict(os.environ, MOCK_ENVIRON): with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {}) result = await async_setup_component(hass, "hassio", {})
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert hass.components.hassio.get_core_info()["version_latest"] == "1.0.0" assert hass.components.hassio.get_core_info()["version_latest"] == "1.0.0"
assert hass.components.hassio.is_hassio() assert hass.components.hassio.is_hassio()
@ -241,8 +242,9 @@ async def test_setup_api_push_api_data(hass, aioclient_mock):
result = await async_setup_component( result = await async_setup_component(
hass, "hassio", {"http": {"server_port": 9999}, "hassio": {}} hass, "hassio", {"http": {"server_port": 9999}, "hassio": {}}
) )
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"] assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 9999 assert aioclient_mock.mock_calls[1][2]["port"] == 9999
@ -257,8 +259,9 @@ async def test_setup_api_push_api_data_server_host(hass, aioclient_mock):
"hassio", "hassio",
{"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}}, {"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}},
) )
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"] assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 9999 assert aioclient_mock.mock_calls[1][2]["port"] == 9999
@ -269,8 +272,9 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, hass_storag
"""Test setup with API push default data.""" """Test setup with API push default data."""
with patch.dict(os.environ, MOCK_ENVIRON): with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}}) result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}})
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"] assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 8123 assert aioclient_mock.mock_calls[1][2]["port"] == 8123
@ -336,8 +340,9 @@ async def test_setup_api_existing_hassio_user(hass, aioclient_mock, hass_storage
hass_storage[STORAGE_KEY] = {"version": 1, "data": {"hassio_user": user.id}} hass_storage[STORAGE_KEY] = {"version": 1, "data": {"hassio_user": user.id}}
with patch.dict(os.environ, MOCK_ENVIRON): with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}}) result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}})
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"] assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 8123 assert aioclient_mock.mock_calls[1][2]["port"] == 8123
@ -350,8 +355,9 @@ async def test_setup_core_push_timezone(hass, aioclient_mock):
with patch.dict(os.environ, MOCK_ENVIRON): with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {"hassio": {}}) result = await async_setup_component(hass, "hassio", {"hassio": {}})
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert aioclient_mock.mock_calls[2][2]["timezone"] == "testzone" assert aioclient_mock.mock_calls[2][2]["timezone"] == "testzone"
@ -367,8 +373,9 @@ async def test_setup_hassio_no_additional_data(hass, aioclient_mock):
os.environ, {"SUPERVISOR_TOKEN": "123456"} os.environ, {"SUPERVISOR_TOKEN": "123456"}
): ):
result = await async_setup_component(hass, "hassio", {"hassio": {}}) result = await async_setup_component(hass, "hassio", {"hassio": {}})
assert result await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert aioclient_mock.mock_calls[-1][3]["Authorization"] == "Bearer 123456" assert aioclient_mock.mock_calls[-1][3]["Authorization"] == "Bearer 123456"
@ -768,9 +775,9 @@ async def test_setup_hardware_integration(hass, aioclient_mock, integration):
return_value=True, return_value=True,
) as mock_setup_entry: ) as mock_setup_entry:
result = await async_setup_component(hass, "hassio", {"hassio": {}}) result = await async_setup_component(hass, "hassio", {"hassio": {}})
assert result
await hass.async_block_till_done() await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16 assert aioclient_mock.call_count == 16
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1

View File

@ -542,8 +542,8 @@ async def test_setting_up_core_update_when_addon_fails(hass, caplog):
"hassio", "hassio",
{"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}}, {"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}},
) )
assert result await hass.async_block_till_done()
await hass.async_block_till_done() assert result
# Verify that the core update entity does exist # Verify that the core update entity does exist
state = hass.states.get("update.home_assistant_core_update") state = hass.states.get("update.home_assistant_core_update")

View File

@ -1428,7 +1428,7 @@ async def test_init_custom_integration(hass):
"homeassistant.loader.async_get_integration", "homeassistant.loader.async_get_integration",
return_value=integration, return_value=integration,
): ):
await hass.config_entries.flow.async_init("bla") await hass.config_entries.flow.async_init("bla", context={"source": "user"})
async def test_support_entry_unload(hass): async def test_support_entry_unload(hass):