Use new try_connect_all discover command in tplink config flow (#128994)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Steven B. 2024-10-29 20:26:34 +00:00 committed by GitHub
parent aaf3039967
commit 46ceccfbb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 240 additions and 45 deletions

View File

@ -162,12 +162,16 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="already_in_progress")
credentials = await get_credentials(self.hass)
try:
# If integration discovery there will be a device or None for dhcp
if device:
self._discovered_device = device
await self._async_try_connect(device, credentials)
else:
await self._async_try_discover_and_update(
host, credentials, raise_on_progress=True
host,
credentials,
raise_on_progress=True,
raise_on_timeout=True,
)
except AuthenticationError:
return await self.async_step_discovery_auth_confirm()
@ -271,7 +275,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
credentials = await get_credentials(self.hass)
try:
device = await self._async_try_discover_and_update(
host, credentials, raise_on_progress=False
host, credentials, raise_on_progress=False, raise_on_timeout=False
) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False
)
except AuthenticationError:
return await self.async_step_user_auth_confirm()
@ -279,6 +285,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "cannot_connect"
placeholders["error"] = str(ex)
else:
if not device:
return await self.async_step_user_auth_confirm()
return self._async_create_entry_from_device(device)
return self.async_show_form(
@ -298,15 +306,20 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
assert self.host is not None
placeholders: dict[str, str] = {CONF_HOST: self.host}
assert self._discovered_device is not None
if user_input:
username = user_input[CONF_USERNAME]
password = user_input[CONF_PASSWORD]
credentials = Credentials(username, password)
device: Device | None
try:
device = await self._async_try_connect(
self._discovered_device, credentials
)
if self._discovered_device:
device = await self._async_try_connect(
self._discovered_device, credentials
)
else:
device = await self._async_try_connect_all(
self.host, credentials=credentials, raise_on_progress=False
)
except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth"
placeholders["error"] = str(ex)
@ -314,11 +327,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "cannot_connect"
placeholders["error"] = str(ex)
else:
await set_credentials(self.hass, username, password)
self.hass.async_create_task(
self._async_reload_requires_auth_entries(), eager_start=False
)
return self._async_create_entry_from_device(device)
if not device:
errors["base"] = "cannot_connect"
placeholders["error"] = "try_connect_all failed"
else:
await set_credentials(self.hass, username, password)
self.hass.async_create_task(
self._async_reload_requires_auth_entries(), eager_start=False
)
return self._async_create_entry_from_device(device)
return self.async_show_form(
step_id="user_auth_confirm",
@ -408,46 +425,68 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
data=data,
)
async def _async_try_connect_all(
self,
host: str,
credentials: Credentials | None,
raise_on_progress: bool,
) -> Device | None:
"""Try to connect to the device speculatively.
The connection parameters aren't known but discovery has failed so try
to connect with tcp.
"""
if credentials:
device = await Discover.try_connect_all(
host,
credentials=credentials,
http_client=create_async_tplink_clientsession(self.hass),
)
else:
# This will just try the legacy protocol that doesn't require auth
# and doesn't use http
try:
device = await Device.connect(config=DeviceConfig(host))
except Exception: # noqa: BLE001
return None
if device:
await self.async_set_unique_id(
dr.format_mac(device.mac),
raise_on_progress=raise_on_progress,
)
return device
async def _async_try_discover_and_update(
self,
host: str,
credentials: Credentials | None,
raise_on_progress: bool,
) -> Device:
raise_on_timeout: bool,
) -> Device | None:
"""Try to discover the device and call update.
Will try to connect to legacy devices if discovery fails.
Will try to connect directly if discovery fails.
"""
self._discovered_device = None
try:
self._discovered_device = await Discover.discover_single(
host, credentials=credentials
)
except TimeoutError as ex:
# Try connect() to legacy devices if discovery fails. This is a
# fallback mechanism for legacy that can handle connections without
# discovery info but if it fails raise the original error which is
# applicable for newer devices.
try:
self._discovered_device = await Device.connect(
config=DeviceConfig(host)
)
except Exception: # noqa: BLE001
# Raise the original error instead of the fallback error
if raise_on_timeout:
raise ex from ex
else:
if TYPE_CHECKING:
# device or exception is always returned unless
# on_unsupported callback was passed to discover_single
assert self._discovered_device
if self._discovered_device.config.uses_http:
self._discovered_device.config.http_client = (
create_async_tplink_clientsession(self.hass)
)
await self._discovered_device.update()
return None
if TYPE_CHECKING:
assert self._discovered_device
await self.async_set_unique_id(
dr.format_mac(self._discovered_device.mac),
raise_on_progress=raise_on_progress,
)
if self._discovered_device.config.uses_http:
self._discovered_device.config.http_client = (
create_async_tplink_clientsession(self.hass)
)
await self._discovered_device.update()
return self._discovered_device
async def _async_try_connect(
@ -496,7 +535,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
device = await self._async_try_discover_and_update(
host,
credentials=credentials,
raise_on_progress=True,
raise_on_progress=False,
raise_on_timeout=False,
) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False
)
except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth"
@ -505,15 +547,23 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "cannot_connect"
placeholders["error"] = str(ex)
else:
await set_credentials(self.hass, username, password)
if updates := self._get_config_updates(reauth_entry, host, device):
self.hass.config_entries.async_update_entry(
reauth_entry, data=updates
if not device:
errors["base"] = "cannot_connect"
placeholders["error"] = "try_connect_all failed"
else:
await self.async_set_unique_id(
dr.format_mac(device.mac),
raise_on_progress=False,
)
self.hass.async_create_task(
self._async_reload_requires_auth_entries(), eager_start=False
)
return self.async_abort(reason="reauth_successful")
await set_credentials(self.hass, username, password)
if updates := self._get_config_updates(reauth_entry, host, device):
self.hass.config_entries.async_update_entry(
reauth_entry, data=updates
)
self.hass.async_create_task(
self._async_reload_requires_auth_entries(), eager_start=False
)
return self.async_abort(reason="reauth_successful")
# Old config entries will not have these values.
alias = entry_data.get(CONF_ALIAS) or "unknown"

View File

@ -32,6 +32,7 @@ def mock_discovery():
"homeassistant.components.tplink.Discover",
discover=DEFAULT,
discover_single=DEFAULT,
try_connect_all=DEFAULT,
) as mock_discovery:
device = _mocked_device(
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
@ -47,6 +48,7 @@ def mock_discovery():
}
mock_discovery["discover"].return_value = devices
mock_discovery["discover_single"].return_value = device
mock_discovery["try_connect_all"].return_value = device
mock_discovery["mock_device"] = device
yield mock_discovery

View File

@ -1023,6 +1023,30 @@ async def test_dhcp_discovery_with_ip_change(
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
async def test_dhcp_discovery_discover_fail(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
) -> None:
"""Test dhcp discovery source cannot discover_single."""
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 0
assert mock_config_entry.data[CONF_HOST] == "127.0.0.1"
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_DHCP},
data=dhcp.DhcpServiceInfo(
ip="127.0.0.2", macaddress=DHCP_FORMATTED_MAC_ADDRESS, hostname=ALIAS
),
)
assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "cannot_connect"
async def test_reauth(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
@ -1057,6 +1081,76 @@ async def test_reauth(
await hass.async_block_till_done()
async def test_reauth_try_connect_all(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
) -> None:
"""Test reauth flow."""
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
assert mock_added_config_entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
)
mock_discovery["try_connect_all"].assert_called_once()
assert result2["type"] is FlowResultType.ABORT
assert result2["reason"] == "reauth_successful"
await hass.async_block_till_done()
async def test_reauth_try_connect_all_fail(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
) -> None:
"""Test reauth flow."""
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
assert mock_added_config_entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
with (
override_side_effect(mock_discovery["discover_single"], TimeoutError),
override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None),
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
)
mock_discovery["try_connect_all"].assert_called_once()
assert result2["errors"] == {"base": "cannot_connect"}
async def test_reauth_update_with_encryption_change(
hass: HomeAssistant,
mock_discovery: AsyncMock,
@ -1398,7 +1492,7 @@ async def test_pick_device_errors(
assert result4["context"]["unique_id"] == MAC_ADDRESS
async def test_discovery_timeout_connect(
async def test_discovery_timeout_try_connect_all(
hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
@ -1424,7 +1518,7 @@ async def test_discovery_timeout_connect(
assert mock_connect["connect"].call_count == 1
async def test_discovery_timeout_connect_legacy_error(
async def test_discovery_timeout_try_connect_all_needs_creds(
hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
@ -1446,8 +1540,57 @@ async def test_discovery_timeout_connect_legacy_error(
result["flow_id"], {CONF_HOST: IP_ADDRESS}
)
await hass.async_block_till_done()
assert result2["step_id"] == "user_auth_confirm"
assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": "cannot_connect"}
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
await hass.async_block_till_done()
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["context"]["unique_id"] == MAC_ADDRESS
assert mock_connect["connect"].call_count == 1
async def test_discovery_timeout_try_connect_all_fail(
hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
mock_init,
) -> None:
"""Test discovery tries legacy connect on timeout."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_discovery["discover_single"].side_effect = TimeoutError
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
assert mock_connect["connect"].call_count == 0
with override_side_effect(mock_connect["connect"], KasaException):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_HOST: IP_ADDRESS}
)
await hass.async_block_till_done()
assert result2["step_id"] == "user_auth_confirm"
assert result2["type"] is FlowResultType.FORM
with override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None):
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
await hass.async_block_till_done()
assert result3["errors"] == {"base": "cannot_connect"}
assert mock_connect["connect"].call_count == 1