mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Use new try_connect_all discover command in tplink config flow (#128994)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
aaf3039967
commit
46ceccfbb3
@ -162,12 +162,16 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
return self.async_abort(reason="already_in_progress")
|
||||
credentials = await get_credentials(self.hass)
|
||||
try:
|
||||
# If integration discovery there will be a device or None for dhcp
|
||||
if device:
|
||||
self._discovered_device = device
|
||||
await self._async_try_connect(device, credentials)
|
||||
else:
|
||||
await self._async_try_discover_and_update(
|
||||
host, credentials, raise_on_progress=True
|
||||
host,
|
||||
credentials,
|
||||
raise_on_progress=True,
|
||||
raise_on_timeout=True,
|
||||
)
|
||||
except AuthenticationError:
|
||||
return await self.async_step_discovery_auth_confirm()
|
||||
@ -271,7 +275,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
credentials = await get_credentials(self.hass)
|
||||
try:
|
||||
device = await self._async_try_discover_and_update(
|
||||
host, credentials, raise_on_progress=False
|
||||
host, credentials, raise_on_progress=False, raise_on_timeout=False
|
||||
) or await self._async_try_connect_all(
|
||||
host, credentials=credentials, raise_on_progress=False
|
||||
)
|
||||
except AuthenticationError:
|
||||
return await self.async_step_user_auth_confirm()
|
||||
@ -279,6 +285,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
errors["base"] = "cannot_connect"
|
||||
placeholders["error"] = str(ex)
|
||||
else:
|
||||
if not device:
|
||||
return await self.async_step_user_auth_confirm()
|
||||
return self._async_create_entry_from_device(device)
|
||||
|
||||
return self.async_show_form(
|
||||
@ -298,15 +306,20 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
assert self.host is not None
|
||||
placeholders: dict[str, str] = {CONF_HOST: self.host}
|
||||
|
||||
assert self._discovered_device is not None
|
||||
if user_input:
|
||||
username = user_input[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
credentials = Credentials(username, password)
|
||||
device: Device | None
|
||||
try:
|
||||
device = await self._async_try_connect(
|
||||
self._discovered_device, credentials
|
||||
)
|
||||
if self._discovered_device:
|
||||
device = await self._async_try_connect(
|
||||
self._discovered_device, credentials
|
||||
)
|
||||
else:
|
||||
device = await self._async_try_connect_all(
|
||||
self.host, credentials=credentials, raise_on_progress=False
|
||||
)
|
||||
except AuthenticationError as ex:
|
||||
errors[CONF_PASSWORD] = "invalid_auth"
|
||||
placeholders["error"] = str(ex)
|
||||
@ -314,11 +327,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
errors["base"] = "cannot_connect"
|
||||
placeholders["error"] = str(ex)
|
||||
else:
|
||||
await set_credentials(self.hass, username, password)
|
||||
self.hass.async_create_task(
|
||||
self._async_reload_requires_auth_entries(), eager_start=False
|
||||
)
|
||||
return self._async_create_entry_from_device(device)
|
||||
if not device:
|
||||
errors["base"] = "cannot_connect"
|
||||
placeholders["error"] = "try_connect_all failed"
|
||||
else:
|
||||
await set_credentials(self.hass, username, password)
|
||||
self.hass.async_create_task(
|
||||
self._async_reload_requires_auth_entries(), eager_start=False
|
||||
)
|
||||
return self._async_create_entry_from_device(device)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user_auth_confirm",
|
||||
@ -408,46 +425,68 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def _async_try_connect_all(
|
||||
self,
|
||||
host: str,
|
||||
credentials: Credentials | None,
|
||||
raise_on_progress: bool,
|
||||
) -> Device | None:
|
||||
"""Try to connect to the device speculatively.
|
||||
|
||||
The connection parameters aren't known but discovery has failed so try
|
||||
to connect with tcp.
|
||||
"""
|
||||
if credentials:
|
||||
device = await Discover.try_connect_all(
|
||||
host,
|
||||
credentials=credentials,
|
||||
http_client=create_async_tplink_clientsession(self.hass),
|
||||
)
|
||||
else:
|
||||
# This will just try the legacy protocol that doesn't require auth
|
||||
# and doesn't use http
|
||||
try:
|
||||
device = await Device.connect(config=DeviceConfig(host))
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
if device:
|
||||
await self.async_set_unique_id(
|
||||
dr.format_mac(device.mac),
|
||||
raise_on_progress=raise_on_progress,
|
||||
)
|
||||
return device
|
||||
|
||||
async def _async_try_discover_and_update(
|
||||
self,
|
||||
host: str,
|
||||
credentials: Credentials | None,
|
||||
raise_on_progress: bool,
|
||||
) -> Device:
|
||||
raise_on_timeout: bool,
|
||||
) -> Device | None:
|
||||
"""Try to discover the device and call update.
|
||||
|
||||
Will try to connect to legacy devices if discovery fails.
|
||||
Will try to connect directly if discovery fails.
|
||||
"""
|
||||
self._discovered_device = None
|
||||
try:
|
||||
self._discovered_device = await Discover.discover_single(
|
||||
host, credentials=credentials
|
||||
)
|
||||
except TimeoutError as ex:
|
||||
# Try connect() to legacy devices if discovery fails. This is a
|
||||
# fallback mechanism for legacy that can handle connections without
|
||||
# discovery info but if it fails raise the original error which is
|
||||
# applicable for newer devices.
|
||||
try:
|
||||
self._discovered_device = await Device.connect(
|
||||
config=DeviceConfig(host)
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
# Raise the original error instead of the fallback error
|
||||
if raise_on_timeout:
|
||||
raise ex from ex
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
# device or exception is always returned unless
|
||||
# on_unsupported callback was passed to discover_single
|
||||
assert self._discovered_device
|
||||
if self._discovered_device.config.uses_http:
|
||||
self._discovered_device.config.http_client = (
|
||||
create_async_tplink_clientsession(self.hass)
|
||||
)
|
||||
await self._discovered_device.update()
|
||||
return None
|
||||
if TYPE_CHECKING:
|
||||
assert self._discovered_device
|
||||
await self.async_set_unique_id(
|
||||
dr.format_mac(self._discovered_device.mac),
|
||||
raise_on_progress=raise_on_progress,
|
||||
)
|
||||
if self._discovered_device.config.uses_http:
|
||||
self._discovered_device.config.http_client = (
|
||||
create_async_tplink_clientsession(self.hass)
|
||||
)
|
||||
await self._discovered_device.update()
|
||||
return self._discovered_device
|
||||
|
||||
async def _async_try_connect(
|
||||
@ -496,7 +535,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
device = await self._async_try_discover_and_update(
|
||||
host,
|
||||
credentials=credentials,
|
||||
raise_on_progress=True,
|
||||
raise_on_progress=False,
|
||||
raise_on_timeout=False,
|
||||
) or await self._async_try_connect_all(
|
||||
host, credentials=credentials, raise_on_progress=False
|
||||
)
|
||||
except AuthenticationError as ex:
|
||||
errors[CONF_PASSWORD] = "invalid_auth"
|
||||
@ -505,15 +547,23 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
errors["base"] = "cannot_connect"
|
||||
placeholders["error"] = str(ex)
|
||||
else:
|
||||
await set_credentials(self.hass, username, password)
|
||||
if updates := self._get_config_updates(reauth_entry, host, device):
|
||||
self.hass.config_entries.async_update_entry(
|
||||
reauth_entry, data=updates
|
||||
if not device:
|
||||
errors["base"] = "cannot_connect"
|
||||
placeholders["error"] = "try_connect_all failed"
|
||||
else:
|
||||
await self.async_set_unique_id(
|
||||
dr.format_mac(device.mac),
|
||||
raise_on_progress=False,
|
||||
)
|
||||
self.hass.async_create_task(
|
||||
self._async_reload_requires_auth_entries(), eager_start=False
|
||||
)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
await set_credentials(self.hass, username, password)
|
||||
if updates := self._get_config_updates(reauth_entry, host, device):
|
||||
self.hass.config_entries.async_update_entry(
|
||||
reauth_entry, data=updates
|
||||
)
|
||||
self.hass.async_create_task(
|
||||
self._async_reload_requires_auth_entries(), eager_start=False
|
||||
)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
# Old config entries will not have these values.
|
||||
alias = entry_data.get(CONF_ALIAS) or "unknown"
|
||||
|
@ -32,6 +32,7 @@ def mock_discovery():
|
||||
"homeassistant.components.tplink.Discover",
|
||||
discover=DEFAULT,
|
||||
discover_single=DEFAULT,
|
||||
try_connect_all=DEFAULT,
|
||||
) as mock_discovery:
|
||||
device = _mocked_device(
|
||||
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
|
||||
@ -47,6 +48,7 @@ def mock_discovery():
|
||||
}
|
||||
mock_discovery["discover"].return_value = devices
|
||||
mock_discovery["discover_single"].return_value = device
|
||||
mock_discovery["try_connect_all"].return_value = device
|
||||
mock_discovery["mock_device"] = device
|
||||
yield mock_discovery
|
||||
|
||||
|
@ -1023,6 +1023,30 @@ async def test_dhcp_discovery_with_ip_change(
|
||||
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
|
||||
|
||||
|
||||
async def test_dhcp_discovery_discover_fail(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_discovery: AsyncMock,
|
||||
mock_connect: AsyncMock,
|
||||
) -> None:
|
||||
"""Test dhcp discovery source cannot discover_single."""
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 0
|
||||
assert mock_config_entry.data[CONF_HOST] == "127.0.0.1"
|
||||
|
||||
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
|
||||
discovery_result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_DHCP},
|
||||
data=dhcp.DhcpServiceInfo(
|
||||
ip="127.0.0.2", macaddress=DHCP_FORMATTED_MAC_ADDRESS, hostname=ALIAS
|
||||
),
|
||||
)
|
||||
assert discovery_result["type"] is FlowResultType.ABORT
|
||||
assert discovery_result["reason"] == "cannot_connect"
|
||||
|
||||
|
||||
async def test_reauth(
|
||||
hass: HomeAssistant,
|
||||
mock_added_config_entry: MockConfigEntry,
|
||||
@ -1057,6 +1081,76 @@ async def test_reauth(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_reauth_try_connect_all(
|
||||
hass: HomeAssistant,
|
||||
mock_added_config_entry: MockConfigEntry,
|
||||
mock_discovery: AsyncMock,
|
||||
mock_connect: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reauth flow."""
|
||||
mock_added_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_added_config_entry.state is ConfigEntryState.LOADED
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
[result] = flows
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
CONF_USERNAME: "fake_username",
|
||||
CONF_PASSWORD: "fake_password",
|
||||
},
|
||||
)
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
)
|
||||
mock_discovery["try_connect_all"].assert_called_once()
|
||||
assert result2["type"] is FlowResultType.ABORT
|
||||
assert result2["reason"] == "reauth_successful"
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_reauth_try_connect_all_fail(
|
||||
hass: HomeAssistant,
|
||||
mock_added_config_entry: MockConfigEntry,
|
||||
mock_discovery: AsyncMock,
|
||||
mock_connect: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reauth flow."""
|
||||
mock_added_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_added_config_entry.state is ConfigEntryState.LOADED
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
[result] = flows
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
with (
|
||||
override_side_effect(mock_discovery["discover_single"], TimeoutError),
|
||||
override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None),
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
CONF_USERNAME: "fake_username",
|
||||
CONF_PASSWORD: "fake_password",
|
||||
},
|
||||
)
|
||||
credentials = Credentials("fake_username", "fake_password")
|
||||
mock_discovery["discover_single"].assert_called_once_with(
|
||||
"127.0.0.1", credentials=credentials
|
||||
)
|
||||
mock_discovery["try_connect_all"].assert_called_once()
|
||||
assert result2["errors"] == {"base": "cannot_connect"}
|
||||
|
||||
|
||||
async def test_reauth_update_with_encryption_change(
|
||||
hass: HomeAssistant,
|
||||
mock_discovery: AsyncMock,
|
||||
@ -1398,7 +1492,7 @@ async def test_pick_device_errors(
|
||||
assert result4["context"]["unique_id"] == MAC_ADDRESS
|
||||
|
||||
|
||||
async def test_discovery_timeout_connect(
|
||||
async def test_discovery_timeout_try_connect_all(
|
||||
hass: HomeAssistant,
|
||||
mock_discovery: AsyncMock,
|
||||
mock_connect: AsyncMock,
|
||||
@ -1424,7 +1518,7 @@ async def test_discovery_timeout_connect(
|
||||
assert mock_connect["connect"].call_count == 1
|
||||
|
||||
|
||||
async def test_discovery_timeout_connect_legacy_error(
|
||||
async def test_discovery_timeout_try_connect_all_needs_creds(
|
||||
hass: HomeAssistant,
|
||||
mock_discovery: AsyncMock,
|
||||
mock_connect: AsyncMock,
|
||||
@ -1446,8 +1540,57 @@ async def test_discovery_timeout_connect_legacy_error(
|
||||
result["flow_id"], {CONF_HOST: IP_ADDRESS}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result2["step_id"] == "user_auth_confirm"
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
assert result2["errors"] == {"base": "cannot_connect"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
user_input={
|
||||
CONF_USERNAME: "fake_username",
|
||||
CONF_PASSWORD: "fake_password",
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result3["context"]["unique_id"] == MAC_ADDRESS
|
||||
assert mock_connect["connect"].call_count == 1
|
||||
|
||||
|
||||
async def test_discovery_timeout_try_connect_all_fail(
|
||||
hass: HomeAssistant,
|
||||
mock_discovery: AsyncMock,
|
||||
mock_connect: AsyncMock,
|
||||
mock_init,
|
||||
) -> None:
|
||||
"""Test discovery tries legacy connect on timeout."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
mock_discovery["discover_single"].side_effect = TimeoutError
|
||||
await hass.async_block_till_done()
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "user"
|
||||
assert not result["errors"]
|
||||
assert mock_connect["connect"].call_count == 0
|
||||
|
||||
with override_side_effect(mock_connect["connect"], KasaException):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_HOST: IP_ADDRESS}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result2["step_id"] == "user_auth_confirm"
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
|
||||
with override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None):
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
user_input={
|
||||
CONF_USERNAME: "fake_username",
|
||||
CONF_PASSWORD: "fake_password",
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result3["errors"] == {"base": "cannot_connect"}
|
||||
assert mock_connect["connect"].call_count == 1
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user