diff --git a/homeassistant/components/tplink/__init__.py b/homeassistant/components/tplink/__init__.py index ceeb1120ed8..ee1d90e70b4 100644 --- a/homeassistant/components/tplink/__init__.py +++ b/homeassistant/components/tplink/__init__.py @@ -31,6 +31,7 @@ from homeassistant.const import ( CONF_MAC, CONF_MODEL, CONF_PASSWORD, + CONF_PORT, CONF_USERNAME, ) from homeassistant.core import HomeAssistant, callback @@ -141,6 +142,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH) entry_use_http = entry.data.get(CONF_USES_HTTP, False) entry_aes_keys = entry.data.get(CONF_AES_KEYS) + port_override = entry.data.get(CONF_PORT) conn_params: Device.ConnectionParameters | None = None if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS): @@ -157,6 +159,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo timeout=CONNECT_TIMEOUT, http_client=client, aes_keys=entry_aes_keys, + port_override=port_override, ) if conn_params: config.connection_type = conn_params diff --git a/homeassistant/components/tplink/config_flow.py b/homeassistant/components/tplink/config_flow.py index a9f665e12fd..63f1b4e125b 100644 --- a/homeassistant/components/tplink/config_flow.py +++ b/homeassistant/components/tplink/config_flow.py @@ -32,6 +32,7 @@ from homeassistant.const import ( CONF_MAC, CONF_MODEL, CONF_PASSWORD, + CONF_PORT, CONF_USERNAME, ) from homeassistant.core import callback @@ -69,6 +70,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION host: str | None = None + port: int | None = None def __init__(self) -> None: """Initialize the config flow.""" @@ -260,6 +262,26 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): step_id="discovery_confirm", description_placeholders=placeholders ) + @staticmethod + def _async_get_host_port(host_str: str) -> tuple[str, int | None]: + """Parse the host string for host and port.""" + if "[" in host_str: + _, _, bracketed = host_str.partition("[") + host, _, port_str = bracketed.partition("]") + _, _, port_str = port_str.partition(":") + else: + host, _, port_str = host_str.partition(":") + + if not port_str: + return host, None + + try: + port = int(port_str) + except ValueError: + return host, None + + return host, port + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: @@ -270,14 +292,29 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): if user_input is not None: if not (host := user_input[CONF_HOST]): return await self.async_step_pick_device() - self._async_abort_entries_match({CONF_HOST: host}) + + host, port = self._async_get_host_port(host) + + match_dict = {CONF_HOST: host} + if port: + self.port = port + match_dict[CONF_PORT] = port + self._async_abort_entries_match(match_dict) + self.host = host credentials = await get_credentials(self.hass) try: device = await self._async_try_discover_and_update( - host, credentials, raise_on_progress=False, raise_on_timeout=False + host, + credentials, + raise_on_progress=False, + raise_on_timeout=False, + port=port, ) or await self._async_try_connect_all( - host, credentials=credentials, raise_on_progress=False + host, + credentials=credentials, + raise_on_progress=False, + port=port, ) except AuthenticationError: return await self.async_step_user_auth_confirm() @@ -318,7 +355,10 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): ) else: device = await self._async_try_connect_all( - self.host, credentials=credentials, raise_on_progress=False + self.host, + credentials=credentials, + raise_on_progress=False, + port=self.port, ) except AuthenticationError as ex: errors[CONF_PASSWORD] = "invalid_auth" @@ -420,6 +460,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): data[CONF_AES_KEYS] = device.config.aes_keys if device.credentials_hash: data[CONF_CREDENTIALS_HASH] = device.credentials_hash + if port := device.config.port_override: + data[CONF_PORT] = port return self.async_create_entry( title=f"{device.alias} {device.model}", data=data, @@ -430,6 +472,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): host: str, credentials: Credentials | None, raise_on_progress: bool, + *, + port: int | None = None, ) -> Device | None: """Try to connect to the device speculatively. @@ -441,12 +485,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): host, credentials=credentials, http_client=create_async_tplink_clientsession(self.hass), + port=port, ) else: # This will just try the legacy protocol that doesn't require auth # and doesn't use http try: - device = await Device.connect(config=DeviceConfig(host)) + device = await Device.connect( + config=DeviceConfig(host, port_override=port) + ) except Exception: # noqa: BLE001 return None if device: @@ -462,6 +509,8 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): credentials: Credentials | None, raise_on_progress: bool, raise_on_timeout: bool, + *, + port: int | None = None, ) -> Device | None: """Try to discover the device and call update. @@ -470,7 +519,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): self._discovered_device = None try: self._discovered_device = await Discover.discover_single( - host, credentials=credentials + host, + credentials=credentials, + port=port, ) except TimeoutError as ex: if raise_on_timeout: @@ -526,6 +577,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): reauth_entry = self._get_reauth_entry() entry_data = reauth_entry.data host = entry_data[CONF_HOST] + port = entry_data.get(CONF_PORT) if user_input: username = user_input[CONF_USERNAME] @@ -537,8 +589,12 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): credentials=credentials, raise_on_progress=False, raise_on_timeout=False, + port=port, ) or await self._async_try_connect_all( - host, credentials=credentials, raise_on_progress=False + host, + credentials=credentials, + raise_on_progress=False, + port=port, ) except AuthenticationError as ex: errors[CONF_PASSWORD] = "invalid_auth" diff --git a/tests/components/tplink/conftest.py b/tests/components/tplink/conftest.py index 78cc9304bf7..25a4bd20270 100644 --- a/tests/components/tplink/conftest.py +++ b/tests/components/tplink/conftest.py @@ -37,7 +37,7 @@ def mock_discovery(): device = _mocked_device( device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), credentials_hash=CREDENTIALS_HASH_KLAP, - alias=None, + alias="My Bulb", ) devices = { "127.0.0.1": _mocked_device( diff --git a/tests/components/tplink/test_config_flow.py b/tests/components/tplink/test_config_flow.py index 12a5741058c..2697696c667 100644 --- a/tests/components/tplink/test_config_flow.py +++ b/tests/components/tplink/test_config_flow.py @@ -2,7 +2,7 @@ from contextlib import contextmanager import logging -from unittest.mock import AsyncMock, patch +from unittest.mock import ANY, AsyncMock, patch from kasa import TimeoutError import pytest @@ -30,6 +30,7 @@ from homeassistant.const import ( CONF_HOST, CONF_MAC, CONF_PASSWORD, + CONF_PORT, CONF_USERNAME, ) from homeassistant.core import HomeAssistant @@ -665,6 +666,93 @@ async def test_manual_auth_errors( await hass.async_block_till_done() +@pytest.mark.parametrize( + ("host_str", "host", "port"), + [ + (f"{IP_ADDRESS}:1234", IP_ADDRESS, 1234), + ("[2001:db8:0::1]:4321", "2001:db8:0::1", 4321), + ], +) +async def test_manual_port_override( + hass: HomeAssistant, + mock_connect: AsyncMock, + mock_discovery: AsyncMock, + host_str, + host, + port, +) -> None: + """Test manually setup.""" + mock_discovery["mock_device"].config.port_override = port + mock_discovery["mock_device"].host = host + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + assert not result["errors"] + + # side_effects to cause auth confirm as the port override usually only + # works with direct connections. + mock_discovery["discover_single"].side_effect = TimeoutError + mock_connect["connect"].side_effect = AuthenticationError + + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_HOST: host_str} + ) + await hass.async_block_till_done() + + assert result2["type"] is FlowResultType.FORM + assert result2["step_id"] == "user_auth_confirm" + assert not result2["errors"] + + creds = Credentials("fake_username", "fake_password") + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) + await hass.async_block_till_done() + mock_discovery["try_connect_all"].assert_called_once_with( + host, credentials=creds, port=port, http_client=ANY + ) + assert result3["type"] is FlowResultType.CREATE_ENTRY + assert result3["title"] == DEFAULT_ENTRY_TITLE + assert result3["data"] == { + **CREATE_ENTRY_DATA_KLAP, + CONF_PORT: port, + CONF_HOST: host, + } + assert result3["context"]["unique_id"] == MAC_ADDRESS + + +async def test_manual_port_override_invalid( + hass: HomeAssistant, mock_connect: AsyncMock, mock_discovery: AsyncMock +) -> None: + """Test manually setup.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + assert not result["errors"] + + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_HOST: f"{IP_ADDRESS}:foo"} + ) + await hass.async_block_till_done() + + mock_discovery["discover_single"].assert_called_once_with( + "127.0.0.1", credentials=None, port=None + ) + + assert result2["type"] is FlowResultType.CREATE_ENTRY + assert result2["title"] == DEFAULT_ENTRY_TITLE + assert result2["data"] == CREATE_ENTRY_DATA_KLAP + assert result2["context"]["unique_id"] == MAC_ADDRESS + + async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None: """Test we get the form with discovery and abort for dhcp source when we get both.""" @@ -1072,7 +1160,7 @@ async def test_reauth( ) credentials = Credentials("fake_username", "fake_password") mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.1", credentials=credentials + "127.0.0.1", credentials=credentials, port=None ) mock_discovery["mock_device"].update.assert_called_once_with() assert result2["type"] is FlowResultType.ABORT @@ -1107,7 +1195,7 @@ async def test_reauth_try_connect_all( ) credentials = Credentials("fake_username", "fake_password") mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.1", credentials=credentials + "127.0.0.1", credentials=credentials, port=None ) mock_discovery["try_connect_all"].assert_called_once() assert result2["type"] is FlowResultType.ABORT @@ -1145,7 +1233,7 @@ async def test_reauth_try_connect_all_fail( ) credentials = Credentials("fake_username", "fake_password") mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.1", credentials=credentials + "127.0.0.1", credentials=credentials, port=None ) mock_discovery["try_connect_all"].assert_called_once() assert result2["errors"] == {"base": "cannot_connect"} @@ -1214,7 +1302,7 @@ async def test_reauth_update_with_encryption_change( assert "Connection type changed for 127.0.0.2" in caplog.text credentials = Credentials("fake_username", "fake_password") mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.2", credentials=credentials + "127.0.0.2", credentials=credentials, port=None ) mock_discovery["mock_device"].update.assert_called_once_with() assert result2["type"] is FlowResultType.ABORT @@ -1416,7 +1504,7 @@ async def test_reauth_errors( credentials = Credentials("fake_username", "fake_password") mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.1", credentials=credentials + "127.0.0.1", credentials=credentials, port=None ) mock_discovery["mock_device"].update.assert_called_once_with() assert result2["type"] is FlowResultType.FORM @@ -1434,7 +1522,7 @@ async def test_reauth_errors( ) mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.1", credentials=credentials + "127.0.0.1", credentials=credentials, port=None ) mock_discovery["mock_device"].update.assert_called_once_with() @@ -1643,7 +1731,7 @@ async def test_reauth_update_other_flows( ) credentials = Credentials("fake_username", "fake_password") mock_discovery["discover_single"].assert_called_once_with( - "127.0.0.1", credentials=credentials + "127.0.0.1", credentials=credentials, port=None ) mock_discovery["mock_device"].update.assert_called_once_with() assert result2["type"] is FlowResultType.ABORT