Enable overriding connection port for tplink devices (#129619)

Enable setting a port override during manual config entry setup.

The feature will be undocumented as it's quite a specialized use case generally used for testing purposes.
This commit is contained in:
Steven B. 2024-11-08 13:41:00 +00:00 committed by GitHub
parent f49547d598
commit 03c3d09583
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 163 additions and 16 deletions

View File

@ -31,6 +31,7 @@ from homeassistant.const import (
CONF_MAC, CONF_MAC,
CONF_MODEL, CONF_MODEL,
CONF_PASSWORD, CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -141,6 +142,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH) entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH)
entry_use_http = entry.data.get(CONF_USES_HTTP, False) entry_use_http = entry.data.get(CONF_USES_HTTP, False)
entry_aes_keys = entry.data.get(CONF_AES_KEYS) entry_aes_keys = entry.data.get(CONF_AES_KEYS)
port_override = entry.data.get(CONF_PORT)
conn_params: Device.ConnectionParameters | None = None conn_params: Device.ConnectionParameters | None = None
if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS): if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS):
@ -157,6 +159,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
timeout=CONNECT_TIMEOUT, timeout=CONNECT_TIMEOUT,
http_client=client, http_client=client,
aes_keys=entry_aes_keys, aes_keys=entry_aes_keys,
port_override=port_override,
) )
if conn_params: if conn_params:
config.connection_type = conn_params config.connection_type = conn_params

View File

@ -32,6 +32,7 @@ from homeassistant.const import (
CONF_MAC, CONF_MAC,
CONF_MODEL, CONF_MODEL,
CONF_PASSWORD, CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import callback from homeassistant.core import callback
@ -69,6 +70,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
host: str | None = None host: str | None = None
port: int | None = None
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the config flow.""" """Initialize the config flow."""
@ -260,6 +262,26 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
step_id="discovery_confirm", description_placeholders=placeholders step_id="discovery_confirm", description_placeholders=placeholders
) )
@staticmethod
def _async_get_host_port(host_str: str) -> tuple[str, int | None]:
"""Parse the host string for host and port."""
if "[" in host_str:
_, _, bracketed = host_str.partition("[")
host, _, port_str = bracketed.partition("]")
_, _, port_str = port_str.partition(":")
else:
host, _, port_str = host_str.partition(":")
if not port_str:
return host, None
try:
port = int(port_str)
except ValueError:
return host, None
return host, port
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
@ -270,14 +292,29 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
if not (host := user_input[CONF_HOST]): if not (host := user_input[CONF_HOST]):
return await self.async_step_pick_device() return await self.async_step_pick_device()
self._async_abort_entries_match({CONF_HOST: host})
host, port = self._async_get_host_port(host)
match_dict = {CONF_HOST: host}
if port:
self.port = port
match_dict[CONF_PORT] = port
self._async_abort_entries_match(match_dict)
self.host = host self.host = host
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:
device = await self._async_try_discover_and_update( device = await self._async_try_discover_and_update(
host, credentials, raise_on_progress=False, raise_on_timeout=False host,
credentials,
raise_on_progress=False,
raise_on_timeout=False,
port=port,
) or await self._async_try_connect_all( ) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False host,
credentials=credentials,
raise_on_progress=False,
port=port,
) )
except AuthenticationError: except AuthenticationError:
return await self.async_step_user_auth_confirm() return await self.async_step_user_auth_confirm()
@ -318,7 +355,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
) )
else: else:
device = await self._async_try_connect_all( device = await self._async_try_connect_all(
self.host, credentials=credentials, raise_on_progress=False self.host,
credentials=credentials,
raise_on_progress=False,
port=self.port,
) )
except AuthenticationError as ex: except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth" errors[CONF_PASSWORD] = "invalid_auth"
@ -420,6 +460,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
data[CONF_AES_KEYS] = device.config.aes_keys data[CONF_AES_KEYS] = device.config.aes_keys
if device.credentials_hash: if device.credentials_hash:
data[CONF_CREDENTIALS_HASH] = device.credentials_hash data[CONF_CREDENTIALS_HASH] = device.credentials_hash
if port := device.config.port_override:
data[CONF_PORT] = port
return self.async_create_entry( return self.async_create_entry(
title=f"{device.alias} {device.model}", title=f"{device.alias} {device.model}",
data=data, data=data,
@ -430,6 +472,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
host: str, host: str,
credentials: Credentials | None, credentials: Credentials | None,
raise_on_progress: bool, raise_on_progress: bool,
*,
port: int | None = None,
) -> Device | None: ) -> Device | None:
"""Try to connect to the device speculatively. """Try to connect to the device speculatively.
@ -441,12 +485,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
host, host,
credentials=credentials, credentials=credentials,
http_client=create_async_tplink_clientsession(self.hass), http_client=create_async_tplink_clientsession(self.hass),
port=port,
) )
else: else:
# This will just try the legacy protocol that doesn't require auth # This will just try the legacy protocol that doesn't require auth
# and doesn't use http # and doesn't use http
try: try:
device = await Device.connect(config=DeviceConfig(host)) device = await Device.connect(
config=DeviceConfig(host, port_override=port)
)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
return None return None
if device: if device:
@ -462,6 +509,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
credentials: Credentials | None, credentials: Credentials | None,
raise_on_progress: bool, raise_on_progress: bool,
raise_on_timeout: bool, raise_on_timeout: bool,
*,
port: int | None = None,
) -> Device | None: ) -> Device | None:
"""Try to discover the device and call update. """Try to discover the device and call update.
@ -470,7 +519,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
self._discovered_device = None self._discovered_device = None
try: try:
self._discovered_device = await Discover.discover_single( self._discovered_device = await Discover.discover_single(
host, credentials=credentials host,
credentials=credentials,
port=port,
) )
except TimeoutError as ex: except TimeoutError as ex:
if raise_on_timeout: if raise_on_timeout:
@ -526,6 +577,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
reauth_entry = self._get_reauth_entry() reauth_entry = self._get_reauth_entry()
entry_data = reauth_entry.data entry_data = reauth_entry.data
host = entry_data[CONF_HOST] host = entry_data[CONF_HOST]
port = entry_data.get(CONF_PORT)
if user_input: if user_input:
username = user_input[CONF_USERNAME] username = user_input[CONF_USERNAME]
@ -537,8 +589,12 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
credentials=credentials, credentials=credentials,
raise_on_progress=False, raise_on_progress=False,
raise_on_timeout=False, raise_on_timeout=False,
port=port,
) or await self._async_try_connect_all( ) or await self._async_try_connect_all(
host, credentials=credentials, raise_on_progress=False host,
credentials=credentials,
raise_on_progress=False,
port=port,
) )
except AuthenticationError as ex: except AuthenticationError as ex:
errors[CONF_PASSWORD] = "invalid_auth" errors[CONF_PASSWORD] = "invalid_auth"

View File

@ -37,7 +37,7 @@ def mock_discovery():
device = _mocked_device( device = _mocked_device(
device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
credentials_hash=CREDENTIALS_HASH_KLAP, credentials_hash=CREDENTIALS_HASH_KLAP,
alias=None, alias="My Bulb",
) )
devices = { devices = {
"127.0.0.1": _mocked_device( "127.0.0.1": _mocked_device(

View File

@ -2,7 +2,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
from unittest.mock import AsyncMock, patch from unittest.mock import ANY, AsyncMock, patch
from kasa import TimeoutError from kasa import TimeoutError
import pytest import pytest
@ -30,6 +30,7 @@ from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_MAC, CONF_MAC,
CONF_PASSWORD, CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -665,6 +666,93 @@ async def test_manual_auth_errors(
await hass.async_block_till_done() await hass.async_block_till_done()
@pytest.mark.parametrize(
("host_str", "host", "port"),
[
(f"{IP_ADDRESS}:1234", IP_ADDRESS, 1234),
("[2001:db8:0::1]:4321", "2001:db8:0::1", 4321),
],
)
async def test_manual_port_override(
hass: HomeAssistant,
mock_connect: AsyncMock,
mock_discovery: AsyncMock,
host_str,
host,
port,
) -> None:
"""Test manually setup."""
mock_discovery["mock_device"].config.port_override = port
mock_discovery["mock_device"].host = host
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
# side_effects to cause auth confirm as the port override usually only
# works with direct connections.
mock_discovery["discover_single"].side_effect = TimeoutError
mock_connect["connect"].side_effect = AuthenticationError
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_HOST: host_str}
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.FORM
assert result2["step_id"] == "user_auth_confirm"
assert not result2["errors"]
creds = Credentials("fake_username", "fake_password")
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
user_input={
CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password",
},
)
await hass.async_block_till_done()
mock_discovery["try_connect_all"].assert_called_once_with(
host, credentials=creds, port=port, http_client=ANY
)
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["title"] == DEFAULT_ENTRY_TITLE
assert result3["data"] == {
**CREATE_ENTRY_DATA_KLAP,
CONF_PORT: port,
CONF_HOST: host,
}
assert result3["context"]["unique_id"] == MAC_ADDRESS
async def test_manual_port_override_invalid(
hass: HomeAssistant, mock_connect: AsyncMock, mock_discovery: AsyncMock
) -> None:
"""Test manually setup."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {CONF_HOST: f"{IP_ADDRESS}:foo"}
)
await hass.async_block_till_done()
mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=None, port=None
)
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == DEFAULT_ENTRY_TITLE
assert result2["data"] == CREATE_ENTRY_DATA_KLAP
assert result2["context"]["unique_id"] == MAC_ADDRESS
async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None: async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
"""Test we get the form with discovery and abort for dhcp source when we get both.""" """Test we get the form with discovery and abort for dhcp source when we get both."""
@ -1072,7 +1160,7 @@ async def test_reauth(
) )
credentials = Credentials("fake_username", "fake_password") credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials "127.0.0.1", credentials=credentials, port=None
) )
mock_discovery["mock_device"].update.assert_called_once_with() mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT
@ -1107,7 +1195,7 @@ async def test_reauth_try_connect_all(
) )
credentials = Credentials("fake_username", "fake_password") credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials "127.0.0.1", credentials=credentials, port=None
) )
mock_discovery["try_connect_all"].assert_called_once() mock_discovery["try_connect_all"].assert_called_once()
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT
@ -1145,7 +1233,7 @@ async def test_reauth_try_connect_all_fail(
) )
credentials = Credentials("fake_username", "fake_password") credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials "127.0.0.1", credentials=credentials, port=None
) )
mock_discovery["try_connect_all"].assert_called_once() mock_discovery["try_connect_all"].assert_called_once()
assert result2["errors"] == {"base": "cannot_connect"} assert result2["errors"] == {"base": "cannot_connect"}
@ -1214,7 +1302,7 @@ async def test_reauth_update_with_encryption_change(
assert "Connection type changed for 127.0.0.2" in caplog.text assert "Connection type changed for 127.0.0.2" in caplog.text
credentials = Credentials("fake_username", "fake_password") credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.2", credentials=credentials "127.0.0.2", credentials=credentials, port=None
) )
mock_discovery["mock_device"].update.assert_called_once_with() mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT
@ -1416,7 +1504,7 @@ async def test_reauth_errors(
credentials = Credentials("fake_username", "fake_password") credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials "127.0.0.1", credentials=credentials, port=None
) )
mock_discovery["mock_device"].update.assert_called_once_with() mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
@ -1434,7 +1522,7 @@ async def test_reauth_errors(
) )
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials "127.0.0.1", credentials=credentials, port=None
) )
mock_discovery["mock_device"].update.assert_called_once_with() mock_discovery["mock_device"].update.assert_called_once_with()
@ -1643,7 +1731,7 @@ async def test_reauth_update_other_flows(
) )
credentials = Credentials("fake_username", "fake_password") credentials = Credentials("fake_username", "fake_password")
mock_discovery["discover_single"].assert_called_once_with( mock_discovery["discover_single"].assert_called_once_with(
"127.0.0.1", credentials=credentials "127.0.0.1", credentials=credentials, port=None
) )
mock_discovery["mock_device"].update.assert_called_once_with() mock_discovery["mock_device"].update.assert_called_once_with()
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT