mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 10:17:09 +00:00
Use new try_connect_all discover command in tplink config flow (#128994)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
aaf3039967
commit
46ceccfbb3
@ -162,12 +162,16 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
return self.async_abort(reason="already_in_progress")
|
return self.async_abort(reason="already_in_progress")
|
||||||
credentials = await get_credentials(self.hass)
|
credentials = await get_credentials(self.hass)
|
||||||
try:
|
try:
|
||||||
|
# If integration discovery there will be a device or None for dhcp
|
||||||
if device:
|
if device:
|
||||||
self._discovered_device = device
|
self._discovered_device = device
|
||||||
await self._async_try_connect(device, credentials)
|
await self._async_try_connect(device, credentials)
|
||||||
else:
|
else:
|
||||||
await self._async_try_discover_and_update(
|
await self._async_try_discover_and_update(
|
||||||
host, credentials, raise_on_progress=True
|
host,
|
||||||
|
credentials,
|
||||||
|
raise_on_progress=True,
|
||||||
|
raise_on_timeout=True,
|
||||||
)
|
)
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
return await self.async_step_discovery_auth_confirm()
|
return await self.async_step_discovery_auth_confirm()
|
||||||
@ -271,7 +275,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
credentials = await get_credentials(self.hass)
|
credentials = await get_credentials(self.hass)
|
||||||
try:
|
try:
|
||||||
device = await self._async_try_discover_and_update(
|
device = await self._async_try_discover_and_update(
|
||||||
host, credentials, raise_on_progress=False
|
host, credentials, raise_on_progress=False, raise_on_timeout=False
|
||||||
|
) or await self._async_try_connect_all(
|
||||||
|
host, credentials=credentials, raise_on_progress=False
|
||||||
)
|
)
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
return await self.async_step_user_auth_confirm()
|
return await self.async_step_user_auth_confirm()
|
||||||
@ -279,6 +285,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
errors["base"] = "cannot_connect"
|
errors["base"] = "cannot_connect"
|
||||||
placeholders["error"] = str(ex)
|
placeholders["error"] = str(ex)
|
||||||
else:
|
else:
|
||||||
|
if not device:
|
||||||
|
return await self.async_step_user_auth_confirm()
|
||||||
return self._async_create_entry_from_device(device)
|
return self._async_create_entry_from_device(device)
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
@ -298,21 +306,30 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
assert self.host is not None
|
assert self.host is not None
|
||||||
placeholders: dict[str, str] = {CONF_HOST: self.host}
|
placeholders: dict[str, str] = {CONF_HOST: self.host}
|
||||||
|
|
||||||
assert self._discovered_device is not None
|
|
||||||
if user_input:
|
if user_input:
|
||||||
username = user_input[CONF_USERNAME]
|
username = user_input[CONF_USERNAME]
|
||||||
password = user_input[CONF_PASSWORD]
|
password = user_input[CONF_PASSWORD]
|
||||||
credentials = Credentials(username, password)
|
credentials = Credentials(username, password)
|
||||||
|
device: Device | None
|
||||||
try:
|
try:
|
||||||
|
if self._discovered_device:
|
||||||
device = await self._async_try_connect(
|
device = await self._async_try_connect(
|
||||||
self._discovered_device, credentials
|
self._discovered_device, credentials
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
device = await self._async_try_connect_all(
|
||||||
|
self.host, credentials=credentials, raise_on_progress=False
|
||||||
|
)
|
||||||
except AuthenticationError as ex:
|
except AuthenticationError as ex:
|
||||||
errors[CONF_PASSWORD] = "invalid_auth"
|
errors[CONF_PASSWORD] = "invalid_auth"
|
||||||
placeholders["error"] = str(ex)
|
placeholders["error"] = str(ex)
|
||||||
except KasaException as ex:
|
except KasaException as ex:
|
||||||
errors["base"] = "cannot_connect"
|
errors["base"] = "cannot_connect"
|
||||||
placeholders["error"] = str(ex)
|
placeholders["error"] = str(ex)
|
||||||
|
else:
|
||||||
|
if not device:
|
||||||
|
errors["base"] = "cannot_connect"
|
||||||
|
placeholders["error"] = "try_connect_all failed"
|
||||||
else:
|
else:
|
||||||
await set_credentials(self.hass, username, password)
|
await set_credentials(self.hass, username, password)
|
||||||
self.hass.async_create_task(
|
self.hass.async_create_task(
|
||||||
@ -408,46 +425,68 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_try_connect_all(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
credentials: Credentials | None,
|
||||||
|
raise_on_progress: bool,
|
||||||
|
) -> Device | None:
|
||||||
|
"""Try to connect to the device speculatively.
|
||||||
|
|
||||||
|
The connection parameters aren't known but discovery has failed so try
|
||||||
|
to connect with tcp.
|
||||||
|
"""
|
||||||
|
if credentials:
|
||||||
|
device = await Discover.try_connect_all(
|
||||||
|
host,
|
||||||
|
credentials=credentials,
|
||||||
|
http_client=create_async_tplink_clientsession(self.hass),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This will just try the legacy protocol that doesn't require auth
|
||||||
|
# and doesn't use http
|
||||||
|
try:
|
||||||
|
device = await Device.connect(config=DeviceConfig(host))
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return None
|
||||||
|
if device:
|
||||||
|
await self.async_set_unique_id(
|
||||||
|
dr.format_mac(device.mac),
|
||||||
|
raise_on_progress=raise_on_progress,
|
||||||
|
)
|
||||||
|
return device
|
||||||
|
|
||||||
async def _async_try_discover_and_update(
|
async def _async_try_discover_and_update(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
credentials: Credentials | None,
|
credentials: Credentials | None,
|
||||||
raise_on_progress: bool,
|
raise_on_progress: bool,
|
||||||
) -> Device:
|
raise_on_timeout: bool,
|
||||||
|
) -> Device | None:
|
||||||
"""Try to discover the device and call update.
|
"""Try to discover the device and call update.
|
||||||
|
|
||||||
Will try to connect to legacy devices if discovery fails.
|
Will try to connect directly if discovery fails.
|
||||||
"""
|
"""
|
||||||
|
self._discovered_device = None
|
||||||
try:
|
try:
|
||||||
self._discovered_device = await Discover.discover_single(
|
self._discovered_device = await Discover.discover_single(
|
||||||
host, credentials=credentials
|
host, credentials=credentials
|
||||||
)
|
)
|
||||||
except TimeoutError as ex:
|
except TimeoutError as ex:
|
||||||
# Try connect() to legacy devices if discovery fails. This is a
|
if raise_on_timeout:
|
||||||
# fallback mechanism for legacy that can handle connections without
|
|
||||||
# discovery info but if it fails raise the original error which is
|
|
||||||
# applicable for newer devices.
|
|
||||||
try:
|
|
||||||
self._discovered_device = await Device.connect(
|
|
||||||
config=DeviceConfig(host)
|
|
||||||
)
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
# Raise the original error instead of the fallback error
|
|
||||||
raise ex from ex
|
raise ex from ex
|
||||||
else:
|
return None
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# device or exception is always returned unless
|
|
||||||
# on_unsupported callback was passed to discover_single
|
|
||||||
assert self._discovered_device
|
assert self._discovered_device
|
||||||
|
await self.async_set_unique_id(
|
||||||
|
dr.format_mac(self._discovered_device.mac),
|
||||||
|
raise_on_progress=raise_on_progress,
|
||||||
|
)
|
||||||
if self._discovered_device.config.uses_http:
|
if self._discovered_device.config.uses_http:
|
||||||
self._discovered_device.config.http_client = (
|
self._discovered_device.config.http_client = (
|
||||||
create_async_tplink_clientsession(self.hass)
|
create_async_tplink_clientsession(self.hass)
|
||||||
)
|
)
|
||||||
await self._discovered_device.update()
|
await self._discovered_device.update()
|
||||||
await self.async_set_unique_id(
|
|
||||||
dr.format_mac(self._discovered_device.mac),
|
|
||||||
raise_on_progress=raise_on_progress,
|
|
||||||
)
|
|
||||||
return self._discovered_device
|
return self._discovered_device
|
||||||
|
|
||||||
async def _async_try_connect(
|
async def _async_try_connect(
|
||||||
@ -496,7 +535,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
device = await self._async_try_discover_and_update(
|
device = await self._async_try_discover_and_update(
|
||||||
host,
|
host,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
raise_on_progress=True,
|
raise_on_progress=False,
|
||||||
|
raise_on_timeout=False,
|
||||||
|
) or await self._async_try_connect_all(
|
||||||
|
host, credentials=credentials, raise_on_progress=False
|
||||||
)
|
)
|
||||||
except AuthenticationError as ex:
|
except AuthenticationError as ex:
|
||||||
errors[CONF_PASSWORD] = "invalid_auth"
|
errors[CONF_PASSWORD] = "invalid_auth"
|
||||||
@ -505,6 +547,14 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
errors["base"] = "cannot_connect"
|
errors["base"] = "cannot_connect"
|
||||||
placeholders["error"] = str(ex)
|
placeholders["error"] = str(ex)
|
||||||
else:
|
else:
|
||||||
|
if not device:
|
||||||
|
errors["base"] = "cannot_connect"
|
||||||
|
placeholders["error"] = "try_connect_all failed"
|
||||||
|
else:
|
||||||
|
await self.async_set_unique_id(
|
||||||
|
dr.format_mac(device.mac),
|
||||||
|
raise_on_progress=False,
|
||||||
|
)
|
||||||
await set_credentials(self.hass, username, password)
|
await set_credentials(self.hass, username, password)
|
||||||
if updates := self._get_config_updates(reauth_entry, host, device):
|
if updates := self._get_config_updates(reauth_entry, host, device):
|
||||||
self.hass.config_entries.async_update_entry(
|
self.hass.config_entries.async_update_entry(
|
||||||
|
@ -32,6 +32,7 @@ def mock_discovery():
|
|||||||
"homeassistant.components.tplink.Discover",
|
"homeassistant.components.tplink.Discover",
|
||||||
discover=DEFAULT,
|
discover=DEFAULT,
|
||||||
discover_single=DEFAULT,
|
discover_single=DEFAULT,
|
||||||
|
try_connect_all=DEFAULT,
|
||||||
) as mock_discovery:
|
) as mock_discovery:
|
||||||
device = _mocked_device(
|
device = _mocked_device(
|
||||||
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
|
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
|
||||||
@ -47,6 +48,7 @@ def mock_discovery():
|
|||||||
}
|
}
|
||||||
mock_discovery["discover"].return_value = devices
|
mock_discovery["discover"].return_value = devices
|
||||||
mock_discovery["discover_single"].return_value = device
|
mock_discovery["discover_single"].return_value = device
|
||||||
|
mock_discovery["try_connect_all"].return_value = device
|
||||||
mock_discovery["mock_device"] = device
|
mock_discovery["mock_device"] = device
|
||||||
yield mock_discovery
|
yield mock_discovery
|
||||||
|
|
||||||
|
@ -1023,6 +1023,30 @@ async def test_dhcp_discovery_with_ip_change(
|
|||||||
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
|
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_dhcp_discovery_discover_fail(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_discovery: AsyncMock,
|
||||||
|
mock_connect: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test dhcp discovery source cannot discover_single."""
|
||||||
|
|
||||||
|
flows = hass.config_entries.flow.async_progress()
|
||||||
|
assert len(flows) == 0
|
||||||
|
assert mock_config_entry.data[CONF_HOST] == "127.0.0.1"
|
||||||
|
|
||||||
|
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
|
||||||
|
discovery_result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={"source": config_entries.SOURCE_DHCP},
|
||||||
|
data=dhcp.DhcpServiceInfo(
|
||||||
|
ip="127.0.0.2", macaddress=DHCP_FORMATTED_MAC_ADDRESS, hostname=ALIAS
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert discovery_result["type"] is FlowResultType.ABORT
|
||||||
|
assert discovery_result["reason"] == "cannot_connect"
|
||||||
|
|
||||||
|
|
||||||
async def test_reauth(
|
async def test_reauth(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_added_config_entry: MockConfigEntry,
|
mock_added_config_entry: MockConfigEntry,
|
||||||
@ -1057,6 +1081,76 @@ async def test_reauth(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reauth_try_connect_all(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_added_config_entry: MockConfigEntry,
|
||||||
|
mock_discovery: AsyncMock,
|
||||||
|
mock_connect: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test reauth flow."""
|
||||||
|
mock_added_config_entry.async_start_reauth(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert mock_added_config_entry.state is ConfigEntryState.LOADED
|
||||||
|
flows = hass.config_entries.flow.async_progress()
|
||||||
|
assert len(flows) == 1
|
||||||
|
[result] = flows
|
||||||
|
assert result["step_id"] == "reauth_confirm"
|
||||||
|
|
||||||
|
with override_side_effect(mock_discovery["discover_single"], TimeoutError):
|
||||||
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
CONF_USERNAME: "fake_username",
|
||||||
|
CONF_PASSWORD: "fake_password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
credentials = Credentials("fake_username", "fake_password")
|
||||||
|
mock_discovery["discover_single"].assert_called_once_with(
|
||||||
|
"127.0.0.1", credentials=credentials
|
||||||
|
)
|
||||||
|
mock_discovery["try_connect_all"].assert_called_once()
|
||||||
|
assert result2["type"] is FlowResultType.ABORT
|
||||||
|
assert result2["reason"] == "reauth_successful"
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reauth_try_connect_all_fail(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_added_config_entry: MockConfigEntry,
|
||||||
|
mock_discovery: AsyncMock,
|
||||||
|
mock_connect: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test reauth flow."""
|
||||||
|
mock_added_config_entry.async_start_reauth(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert mock_added_config_entry.state is ConfigEntryState.LOADED
|
||||||
|
flows = hass.config_entries.flow.async_progress()
|
||||||
|
assert len(flows) == 1
|
||||||
|
[result] = flows
|
||||||
|
assert result["step_id"] == "reauth_confirm"
|
||||||
|
|
||||||
|
with (
|
||||||
|
override_side_effect(mock_discovery["discover_single"], TimeoutError),
|
||||||
|
override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None),
|
||||||
|
):
|
||||||
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
CONF_USERNAME: "fake_username",
|
||||||
|
CONF_PASSWORD: "fake_password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
credentials = Credentials("fake_username", "fake_password")
|
||||||
|
mock_discovery["discover_single"].assert_called_once_with(
|
||||||
|
"127.0.0.1", credentials=credentials
|
||||||
|
)
|
||||||
|
mock_discovery["try_connect_all"].assert_called_once()
|
||||||
|
assert result2["errors"] == {"base": "cannot_connect"}
|
||||||
|
|
||||||
|
|
||||||
async def test_reauth_update_with_encryption_change(
|
async def test_reauth_update_with_encryption_change(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_discovery: AsyncMock,
|
mock_discovery: AsyncMock,
|
||||||
@ -1398,7 +1492,7 @@ async def test_pick_device_errors(
|
|||||||
assert result4["context"]["unique_id"] == MAC_ADDRESS
|
assert result4["context"]["unique_id"] == MAC_ADDRESS
|
||||||
|
|
||||||
|
|
||||||
async def test_discovery_timeout_connect(
|
async def test_discovery_timeout_try_connect_all(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_discovery: AsyncMock,
|
mock_discovery: AsyncMock,
|
||||||
mock_connect: AsyncMock,
|
mock_connect: AsyncMock,
|
||||||
@ -1424,7 +1518,7 @@ async def test_discovery_timeout_connect(
|
|||||||
assert mock_connect["connect"].call_count == 1
|
assert mock_connect["connect"].call_count == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_discovery_timeout_connect_legacy_error(
|
async def test_discovery_timeout_try_connect_all_needs_creds(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_discovery: AsyncMock,
|
mock_discovery: AsyncMock,
|
||||||
mock_connect: AsyncMock,
|
mock_connect: AsyncMock,
|
||||||
@ -1446,8 +1540,57 @@ async def test_discovery_timeout_connect_legacy_error(
|
|||||||
result["flow_id"], {CONF_HOST: IP_ADDRESS}
|
result["flow_id"], {CONF_HOST: IP_ADDRESS}
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
assert result2["step_id"] == "user_auth_confirm"
|
||||||
assert result2["type"] is FlowResultType.FORM
|
assert result2["type"] is FlowResultType.FORM
|
||||||
assert result2["errors"] == {"base": "cannot_connect"}
|
|
||||||
|
result3 = await hass.config_entries.flow.async_configure(
|
||||||
|
result2["flow_id"],
|
||||||
|
user_input={
|
||||||
|
CONF_USERNAME: "fake_username",
|
||||||
|
CONF_PASSWORD: "fake_password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result3["context"]["unique_id"] == MAC_ADDRESS
|
||||||
|
assert mock_connect["connect"].call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_discovery_timeout_try_connect_all_fail(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_discovery: AsyncMock,
|
||||||
|
mock_connect: AsyncMock,
|
||||||
|
mock_init,
|
||||||
|
) -> None:
|
||||||
|
"""Test discovery tries legacy connect on timeout."""
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
mock_discovery["discover_single"].side_effect = TimeoutError
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "user"
|
||||||
|
assert not result["errors"]
|
||||||
|
assert mock_connect["connect"].call_count == 0
|
||||||
|
|
||||||
|
with override_side_effect(mock_connect["connect"], KasaException):
|
||||||
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {CONF_HOST: IP_ADDRESS}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result2["step_id"] == "user_auth_confirm"
|
||||||
|
assert result2["type"] is FlowResultType.FORM
|
||||||
|
|
||||||
|
with override_side_effect(mock_discovery["try_connect_all"], lambda *_, **__: None):
|
||||||
|
result3 = await hass.config_entries.flow.async_configure(
|
||||||
|
result2["flow_id"],
|
||||||
|
user_input={
|
||||||
|
CONF_USERNAME: "fake_username",
|
||||||
|
CONF_PASSWORD: "fake_password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result3["errors"] == {"base": "cannot_connect"}
|
||||||
assert mock_connect["connect"].call_count == 1
|
assert mock_connect["connect"].call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user