From 46ceccfbb35ffc1385c4786b02adbab2c6c8b0ed Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Tue, 29 Oct 2024 20:26:34 +0000 Subject: [PATCH] Use new try_connect_all discover command in tplink config flow (#128994) Co-authored-by: J. Nick Koston --- .../components/tplink/config_flow.py | 134 +++++++++++----- tests/components/tplink/conftest.py | 2 + tests/components/tplink/test_config_flow.py | 149 +++++++++++++++++- 3 files changed, 240 insertions(+), 45 deletions(-) diff --git a/homeassistant/components/tplink/config_flow.py b/homeassistant/components/tplink/config_flow.py index 611ab3ac9fc..a9f665e12fd 100644 --- a/homeassistant/components/tplink/config_flow.py +++ b/homeassistant/components/tplink/config_flow.py @@ -162,12 +162,16 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_abort(reason="already_in_progress") credentials = await get_credentials(self.hass) try: + # If integration discovery there will be a device or None for dhcp if device: self._discovered_device = device await self._async_try_connect(device, credentials) else: await self._async_try_discover_and_update( - host, credentials, raise_on_progress=True + host, + credentials, + raise_on_progress=True, + raise_on_timeout=True, ) except AuthenticationError: return await self.async_step_discovery_auth_confirm() @@ -271,7 +275,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): credentials = await get_credentials(self.hass) try: device = await self._async_try_discover_and_update( - host, credentials, raise_on_progress=False + host, credentials, raise_on_progress=False, raise_on_timeout=False + ) or await self._async_try_connect_all( + host, credentials=credentials, raise_on_progress=False ) except AuthenticationError: return await self.async_step_user_auth_confirm() @@ -279,6 +285,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): errors["base"] = "cannot_connect" placeholders["error"] = str(ex) else: + if not device: + return await self.async_step_user_auth_confirm() return self._async_create_entry_from_device(device) return self.async_show_form( @@ -298,15 +306,20 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): assert self.host is not None placeholders: dict[str, str] = {CONF_HOST: self.host} - assert self._discovered_device is not None if user_input: username = user_input[CONF_USERNAME] password = user_input[CONF_PASSWORD] credentials = Credentials(username, password) + device: Device | None try: - device = await self._async_try_connect( - self._discovered_device, credentials - ) + if self._discovered_device: + device = await self._async_try_connect( + self._discovered_device, credentials + ) + else: + device = await self._async_try_connect_all( + self.host, credentials=credentials, raise_on_progress=False + ) except AuthenticationError as ex: errors[CONF_PASSWORD] = "invalid_auth" placeholders["error"] = str(ex) @@ -314,11 +327,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): errors["base"] = "cannot_connect" placeholders["error"] = str(ex) else: - await set_credentials(self.hass, username, password) - self.hass.async_create_task( - self._async_reload_requires_auth_entries(), eager_start=False - ) - return self._async_create_entry_from_device(device) + if not device: + errors["base"] = "cannot_connect" + placeholders["error"] = "try_connect_all failed" + else: + await set_credentials(self.hass, username, password) + self.hass.async_create_task( + self._async_reload_requires_auth_entries(), eager_start=False + ) + return self._async_create_entry_from_device(device) return self.async_show_form( step_id="user_auth_confirm", @@ -408,46 +425,68 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): data=data, ) + async def _async_try_connect_all( + self, + host: str, + credentials: Credentials | None, + raise_on_progress: bool, + ) -> Device | None: + """Try to connect to the device speculatively. + + The connection parameters aren't known but discovery has failed so try + to connect with tcp. + """ + if credentials: + device = await Discover.try_connect_all( + host, + credentials=credentials, + http_client=create_async_tplink_clientsession(self.hass), + ) + else: + # This will just try the legacy protocol that doesn't require auth + # and doesn't use http + try: + device = await Device.connect(config=DeviceConfig(host)) + except Exception: # noqa: BLE001 + return None + if device: + await self.async_set_unique_id( + dr.format_mac(device.mac), + raise_on_progress=raise_on_progress, + ) + return device + async def _async_try_discover_and_update( self, host: str, credentials: Credentials | None, raise_on_progress: bool, - ) -> Device: + raise_on_timeout: bool, + ) -> Device | None: """Try to discover the device and call update. - Will try to connect to legacy devices if discovery fails. + Will try to connect directly if discovery fails. """ + self._discovered_device = None try: self._discovered_device = await Discover.discover_single( host, credentials=credentials ) except TimeoutError as ex: - # Try connect() to legacy devices if discovery fails. This is a - # fallback mechanism for legacy that can handle connections without - # discovery info but if it fails raise the original error which is - # applicable for newer devices. - try: - self._discovered_device = await Device.connect( - config=DeviceConfig(host) - ) - except Exception: # noqa: BLE001 - # Raise the original error instead of the fallback error + if raise_on_timeout: raise ex from ex - else: - if TYPE_CHECKING: - # device or exception is always returned unless - # on_unsupported callback was passed to discover_single - assert self._discovered_device - if self._discovered_device.config.uses_http: - self._discovered_device.config.http_client = ( - create_async_tplink_clientsession(self.hass) - ) - await self._discovered_device.update() + return None + if TYPE_CHECKING: + assert self._discovered_device await self.async_set_unique_id( dr.format_mac(self._discovered_device.mac), raise_on_progress=raise_on_progress, ) + if self._discovered_device.config.uses_http: + self._discovered_device.config.http_client = ( + create_async_tplink_clientsession(self.hass) + ) + await self._discovered_device.update() return self._discovered_device async def _async_try_connect( @@ -496,7 +535,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): device = await self._async_try_discover_and_update( host, credentials=credentials, - raise_on_progress=True, + raise_on_progress=False, + raise_on_timeout=False, + ) or await self._async_try_connect_all( + host, credentials=credentials, raise_on_progress=False ) except AuthenticationError as ex: errors[CONF_PASSWORD] = "invalid_auth" @@ -505,15 +547,23 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): errors["base"] = "cannot_connect" placeholders["error"] = str(ex) else: - await set_credentials(self.hass, username, password) - if updates := self._get_config_updates(reauth_entry, host, device): - self.hass.config_entries.async_update_entry( - reauth_entry, data=updates + if not device: + errors["base"] = "cannot_connect" + placeholders["error"] = "try_connect_all failed" + else: + await self.async_set_unique_id( + dr.format_mac(device.mac), + raise_on_progress=False, ) - self.hass.async_create_task( - self._async_reload_requires_auth_entries(), eager_start=False - ) - return self.async_abort(reason="reauth_successful") + await set_credentials(self.hass, username, password) + if updates := self._get_config_updates(reauth_entry, host, device): + self.hass.config_entries.async_update_entry( + reauth_entry, data=updates + ) + self.hass.async_create_task( + self._async_reload_requires_auth_entries(), eager_start=False + ) + return self.async_abort(reason="reauth_successful") # Old config entries will not have these values. alias = entry_data.get(CONF_ALIAS) or "unknown" diff --git a/tests/components/tplink/conftest.py b/tests/components/tplink/conftest.py index f1586ee4a0a..78cc9304bf7 100644 --- a/tests/components/tplink/conftest.py +++ b/tests/components/tplink/conftest.py @@ -32,6 +32,7 @@ def mock_discovery(): "homeassistant.components.tplink.Discover", discover=DEFAULT, discover_single=DEFAULT, + try_connect_all=DEFAULT, ) as mock_discovery: device = _mocked_device( device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), @@ -47,6 +48,7 @@ def mock_discovery(): } mock_discovery["discover"].return_value = devices mock_discovery["discover_single"].return_value = device + mock_discovery["try_connect_all"].return_value = device mock_discovery["mock_device"] = device yield mock_discovery diff --git a/tests/components/tplink/test_config_flow.py b/tests/components/tplink/test_config_flow.py index 40bd4383513..12a5741058c 100644 --- a/tests/components/tplink/test_config_flow.py +++ b/tests/components/tplink/test_config_flow.py @@ -1023,6 +1023,30 @@ async def test_dhcp_discovery_with_ip_change( assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" +async def test_dhcp_discovery_discover_fail( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_discovery: AsyncMock, + mock_connect: AsyncMock, +) -> None: + """Test dhcp discovery source cannot discover_single.""" + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 0 + assert mock_config_entry.data[CONF_HOST] == "127.0.0.1" + + with override_side_effect(mock_discovery["discover_single"], TimeoutError): + discovery_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_DHCP}, + data=dhcp.DhcpServiceInfo( + ip="127.0.0.2", macaddress=DHCP_FORMATTED_MAC_ADDRESS, hostname=ALIAS + ), + ) + assert discovery_result["type"] is FlowResultType.ABORT + assert discovery_result["reason"] == "cannot_connect" + + async def test_reauth( hass: HomeAssistant, mock_added_config_entry: MockConfigEntry, @@ -1057,6 +1081,76 @@ async def test_reauth( await hass.async_block_till_done() +async def test_reauth_try_connect_all( + hass: HomeAssistant, + mock_added_config_entry: MockConfigEntry, + mock_discovery: AsyncMock, + mock_connect: AsyncMock, +) -> None: + """Test reauth flow.""" + mock_added_config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + + assert mock_added_config_entry.state is ConfigEntryState.LOADED + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + [result] = flows + assert result["step_id"] == "reauth_confirm" + + with override_side_effect(mock_discovery["discover_single"], TimeoutError): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) + credentials = Credentials("fake_username", "fake_password") + mock_discovery["discover_single"].assert_called_once_with( + "127.0.0.1", credentials=credentials + ) + mock_discovery["try_connect_all"].assert_called_once() + assert result2["type"] is FlowResultType.ABORT + assert result2["reason"] == "reauth_successful" + + await hass.async_block_till_done() + + +async def test_reauth_try_connect_all_fail( + hass: HomeAssistant, + mock_added_config_entry: MockConfigEntry, + mock_discovery: AsyncMock, + mock_connect: AsyncMock, +) -> None: + """Test reauth flow.""" + mock_added_config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + + assert mock_added_config_entry.state is ConfigEntryState.LOADED + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + [result] = flows + assert result["step_id"] == "reauth_confirm" + + with ( + override_side_effect(mock_discovery["discover_single"], TimeoutError), + override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None), + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) + credentials = Credentials("fake_username", "fake_password") + mock_discovery["discover_single"].assert_called_once_with( + "127.0.0.1", credentials=credentials + ) + mock_discovery["try_connect_all"].assert_called_once() + assert result2["errors"] == {"base": "cannot_connect"} + + async def test_reauth_update_with_encryption_change( hass: HomeAssistant, mock_discovery: AsyncMock, @@ -1398,7 +1492,7 @@ async def test_pick_device_errors( assert result4["context"]["unique_id"] == MAC_ADDRESS -async def test_discovery_timeout_connect( +async def test_discovery_timeout_try_connect_all( hass: HomeAssistant, mock_discovery: AsyncMock, mock_connect: AsyncMock, @@ -1424,7 +1518,7 @@ async def test_discovery_timeout_connect( assert mock_connect["connect"].call_count == 1 -async def test_discovery_timeout_connect_legacy_error( +async def test_discovery_timeout_try_connect_all_needs_creds( hass: HomeAssistant, mock_discovery: AsyncMock, mock_connect: AsyncMock, @@ -1446,8 +1540,57 @@ async def test_discovery_timeout_connect_legacy_error( result["flow_id"], {CONF_HOST: IP_ADDRESS} ) await hass.async_block_till_done() + assert result2["step_id"] == "user_auth_confirm" assert result2["type"] is FlowResultType.FORM - assert result2["errors"] == {"base": "cannot_connect"} + + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) + await hass.async_block_till_done() + assert result3["type"] is FlowResultType.CREATE_ENTRY + assert result3["context"]["unique_id"] == MAC_ADDRESS + assert mock_connect["connect"].call_count == 1 + + +async def test_discovery_timeout_try_connect_all_fail( + hass: HomeAssistant, + mock_discovery: AsyncMock, + mock_connect: AsyncMock, + mock_init, +) -> None: + """Test discovery tries legacy connect on timeout.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + mock_discovery["discover_single"].side_effect = TimeoutError + await hass.async_block_till_done() + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + assert not result["errors"] + assert mock_connect["connect"].call_count == 0 + + with override_side_effect(mock_connect["connect"], KasaException): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_HOST: IP_ADDRESS} + ) + await hass.async_block_till_done() + assert result2["step_id"] == "user_auth_confirm" + assert result2["type"] is FlowResultType.FORM + + with override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None): + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) + await hass.async_block_till_done() + assert result3["errors"] == {"base": "cannot_connect"} assert mock_connect["connect"].call_count == 1