Use new try_connect_all discover command in tplink config flow (#128994)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Steven B. 2024-10-29 20:26:34 +00:00 committed by GitHub
parent aaf3039967
commit 46ceccfbb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 240 additions and 45 deletions

View File

@ -162,12 +162,16 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="already_in_progress") return self.async_abort(reason="already_in_progress")
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:
# If integration discovery there will be a device or None for dhcp
if device: if device:
self._discovered_device = device self._discovered_device = device
await self._async_try_connect(device, credentials) await self._async_try_connect(device, credentials)
else: else:
await self._async_try_discover_and_update( await self._async_try_discover_and_update(
host, credentials, raise_on_progress=True host,
credentials,
raise_on_progress=True,
raise_on_timeout=True,
) )
except AuthenticationError: except AuthenticationError:
return await self.async_step_discovery_auth_confirm() return await self.async_step_discovery_auth_confirm()
@ -271,7 +275,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:
device = await self._async_try_discover_and_update( device = await self._async_try_discover_and_update(
host, credentials, raise_on_progress=False host, credentials, raise_on_progress=False, raise_on_timeout=False
) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False
) )
except AuthenticationError: except AuthenticationError:
return await self.async_step_user_auth_confirm() return await self.async_step_user_auth_confirm()
@ -279,6 +285,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
placeholders["error"] = str(ex) placeholders["error"] = str(ex)
else: else:
if not device:
return await self.async_step_user_auth_confirm()
return self._async_create_entry_from_device(device) return self._async_create_entry_from_device(device)
return self.async_show_form( return self.async_show_form(
@ -298,21 +306,30 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
assert self.host is not None assert self.host is not None
placeholders: dict[str, str] = {CONF_HOST: self.host} placeholders: dict[str, str] = {CONF_HOST: self.host}
assert self._discovered_device is not None
if user_input: if user_input:
username = user_input[CONF_USERNAME] username = user_input[CONF_USERNAME]
password = user_input[CONF_PASSWORD] password = user_input[CONF_PASSWORD]
credentials = Credentials(username, password) credentials = Credentials(username, password)
device: Device | None
try: try:
if self._discovered_device:
device = await self._async_try_connect( device = await self._async_try_connect(
self._discovered_device, credentials self._discovered_device, credentials
) )
else:
device = await self._async_try_connect_all(
self.host, credentials=credentials, raise_on_progress=False
)
except AuthenticationError as ex: except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth" errors[CONF_PASSWORD] = "invalid_auth"
placeholders["error"] = str(ex) placeholders["error"] = str(ex)
except KasaException as ex: except KasaException as ex:
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
placeholders["error"] = str(ex) placeholders["error"] = str(ex)
else:
if not device:
errors["base"] = "cannot_connect"
placeholders["error"] = "try_connect_all failed"
else: else:
await set_credentials(self.hass, username, password) await set_credentials(self.hass, username, password)
self.hass.async_create_task( self.hass.async_create_task(
@ -408,46 +425,68 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
data=data, data=data,
) )
async def _async_try_connect_all(
self,
host: str,
credentials: Credentials | None,
raise_on_progress: bool,
) -> Device | None:
"""Try to connect to the device speculatively.
The connection parameters aren't known but discovery has failed so try
to connect with tcp.
"""
if credentials:
device = await Discover.try_connect_all(
host,
credentials=credentials,
http_client=create_async_tplink_clientsession(self.hass),
)
else:
# This will just try the legacy protocol that doesn't require auth
# and doesn't use http
try:
device = await Device.connect(config=DeviceConfig(host))
except Exception: # noqa: BLE001
return None
if device:
await self.async_set_unique_id(
dr.format_mac(device.mac),
raise_on_progress=raise_on_progress,
)
return device
async def _async_try_discover_and_update( async def _async_try_discover_and_update(
self, self,
host: str, host: str,
credentials: Credentials | None, credentials: Credentials | None,
raise_on_progress: bool, raise_on_progress: bool,
) -> Device: raise_on_timeout: bool,
) -> Device | None:
"""Try to discover the device and call update. """Try to discover the device and call update.
Will try to connect to legacy devices if discovery fails. Will try to connect directly if discovery fails.
""" """
self._discovered_device = None
try: try:
self._discovered_device = await Discover.discover_single( self._discovered_device = await Discover.discover_single(
host, credentials=credentials host, credentials=credentials
) )
except TimeoutError as ex: except TimeoutError as ex:
# Try connect() to legacy devices if discovery fails. This is a if raise_on_timeout:
# fallback mechanism for legacy that can handle connections without
# discovery info but if it fails raise the original error which is
# applicable for newer devices.
try:
self._discovered_device = await Device.connect(
config=DeviceConfig(host)
)
except Exception: # noqa: BLE001
# Raise the original error instead of the fallback error
raise ex from ex raise ex from ex
else: return None
if TYPE_CHECKING: if TYPE_CHECKING:
# device or exception is always returned unless
# on_unsupported callback was passed to discover_single
assert self._discovered_device assert self._discovered_device
await self.async_set_unique_id(
dr.format_mac(self._discovered_device.mac),
raise_on_progress=raise_on_progress,
)
if self._discovered_device.config.uses_http: if self._discovered_device.config.uses_http:
self._discovered_device.config.http_client = ( self._discovered_device.config.http_client = (
create_async_tplink_clientsession(self.hass) create_async_tplink_clientsession(self.hass)
) )
await self._discovered_device.update() await self._discovered_device.update()
await self.async_set_unique_id(
dr.format_mac(self._discovered_device.mac),
raise_on_progress=raise_on_progress,
)
return self._discovered_device return self._discovered_device
async def _async_try_connect( async def _async_try_connect(
@ -496,7 +535,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
device = await self._async_try_discover_and_update( device = await self._async_try_discover_and_update(
host, host,
credentials=credentials, credentials=credentials,
raise_on_progress=True, raise_on_progress=False,
raise_on_timeout=False,
) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False
) )
except AuthenticationError as ex: except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth" errors[CONF_PASSWORD] = "invalid_auth"
@ -505,6 +547,14 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
placeholders["error"] = str(ex) placeholders["error"] = str(ex)
else: else:
if not device:
errors["base"] = "cannot_connect"
placeholders["error"] = "try_connect_all failed"
else:
await self.async_set_unique_id(
dr.format_mac(device.mac),
raise_on_progress=False,
)
await set_credentials(self.hass, username, password) await set_credentials(self.hass, username, password)
if updates := self._get_config_updates(reauth_entry, host, device): if updates := self._get_config_updates(reauth_entry, host, device):
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(

View File

@ -32,6 +32,7 @@ def mock_discovery():
"homeassistant.components.tplink.Discover", "homeassistant.components.tplink.Discover",
discover=DEFAULT, discover=DEFAULT,
discover_single=DEFAULT, discover_single=DEFAULT,
try_connect_all=DEFAULT,
) as mock_discovery: ) as mock_discovery:
device = _mocked_device( device = _mocked_device(
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
@ -47,6 +48,7 @@ def mock_discovery():
} }
mock_discovery["discover"].return_value = devices mock_discovery["discover"].return_value = devices
mock_discovery["discover_single"].return_value = device mock_discovery["discover_single"].return_value = device
mock_discovery["try_connect_all"].return_value = device
mock_discovery["mock_device"] = device mock_discovery["mock_device"] = device
yield mock_discovery yield mock_discovery

View File

@ -1023,6 +1023,30 @@ async def test_dhcp_discovery_with_ip_change(
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
async def test_dhcp_discovery_discover_fail(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
) -> None:
"""Test dhcp discovery source cannot discover_single."""
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 0
assert mock_config_entry.data[CONF_HOST] == "127.0.0.1"
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_DHCP},
data=dhcp.DhcpServiceInfo(
ip="127.0.0.2", macaddress=DHCP_FORMATTED_MAC_ADDRESS, hostname=ALIAS
),
)
assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "cannot_connect"
async def test_reauth( async def test_reauth(
hass: HomeAssistant, hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry, mock_added_config_entry: MockConfigEntry,
@ -1057,6 +1081,76 @@ async def test_reauth(
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_reauth_try_connect_all(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
) -> None:
"""Test reauth flow."""
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
assert mock_added_config_entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
)
mock_discovery["try_connect_all"].assert_called_once()
assert result2["type"] is FlowResultType.ABORT
assert result2["reason"] == "reauth_successful"
await hass.async_block_till_done()
async def test_reauth_try_connect_all_fail(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
) -> None:
"""Test reauth flow."""
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
assert mock_added_config_entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
with (
override_side_effect(mock_discovery["discover_single"], TimeoutError),
override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None),
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials
)
mock_discovery["try_connect_all"].assert_called_once()
assert result2["errors"] == {"base": "cannot_connect"}
async def test_reauth_update_with_encryption_change( async def test_reauth_update_with_encryption_change(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock, mock_discovery: AsyncMock,
@ -1398,7 +1492,7 @@ async def test_pick_device_errors(
assert result4["context"]["unique_id"] == MAC_ADDRESS assert result4["context"]["unique_id"] == MAC_ADDRESS
async def test_discovery_timeout_connect( async def test_discovery_timeout_try_connect_all(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock, mock_discovery: AsyncMock,
mock_connect: AsyncMock, mock_connect: AsyncMock,
@ -1424,7 +1518,7 @@ async def test_discovery_timeout_connect(
assert mock_connect["connect"].call_count == 1 assert mock_connect["connect"].call_count == 1
async def test_discovery_timeout_connect_legacy_error( async def test_discovery_timeout_try_connect_all_needs_creds(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock, mock_discovery: AsyncMock,
mock_connect: AsyncMock, mock_connect: AsyncMock,
@ -1446,8 +1540,57 @@ async def test_discovery_timeout_connect_legacy_error(
result["flow_id"], {CONF_HOST: IP_ADDRESS} result["flow_id"], {CONF_HOST: IP_ADDRESS}
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["step_id"] == "user_auth_confirm"
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": "cannot_connect"}
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
await hass.async_block_till_done()
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["context"]["unique_id"] == MAC_ADDRESS
assert mock_connect["connect"].call_count == 1
async def test_discovery_timeout_try_connect_all_fail(
hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock,
mock_init,
) -> None:
"""Test discovery tries legacy connect on timeout."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_discovery["discover_single"].side_effect = TimeoutError
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
assert mock_connect["connect"].call_count == 0
with override_side_effect(mock_connect["connect"], KasaException):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_HOST: IP_ADDRESS}
)
await hass.async_block_till_done()
assert result2["step_id"] == "user_auth_confirm"
assert result2["type"] is FlowResultType.FORM
with override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None):
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
await hass.async_block_till_done()
assert result3["errors"] == {"base": "cannot_connect"}
assert mock_connect["connect"].call_count == 1 assert mock_connect["connect"].call_count == 1