diff --git a/homeassistant/components/tplink/__init__.py b/homeassistant/components/tplink/__init__.py index 83cfc733716..ceeb1120ed8 100644 --- a/homeassistant/components/tplink/__init__.py +++ b/homeassistant/components/tplink/__init__.py @@ -26,6 +26,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_ALIAS, CONF_AUTHENTICATION, + CONF_DEVICE, CONF_HOST, CONF_MAC, CONF_MODEL, @@ -44,8 +45,12 @@ from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType from .const import ( + CONF_AES_KEYS, + CONF_CONFIG_ENTRY_MINOR_VERSION, + CONF_CONNECTION_PARAMETERS, CONF_CREDENTIALS_HASH, CONF_DEVICE_CONFIG, + CONF_USES_HTTP, CONNECT_TIMEOUT, DISCOVERY_TIMEOUT, DOMAIN, @@ -85,9 +90,7 @@ def async_trigger_discovery( CONF_ALIAS: device.alias or mac_alias(device.mac), CONF_HOST: device.host, CONF_MAC: formatted_mac, - CONF_DEVICE_CONFIG: device.config.to_dict( - exclude_credentials=True, - ), + CONF_DEVICE: device, }, ) @@ -136,25 +139,27 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo host: str = entry.data[CONF_HOST] credentials = await get_credentials(hass) entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH) + entry_use_http = entry.data.get(CONF_USES_HTTP, False) + entry_aes_keys = entry.data.get(CONF_AES_KEYS) - config: DeviceConfig | None = None - if config_dict := entry.data.get(CONF_DEVICE_CONFIG): + conn_params: Device.ConnectionParameters | None = None + if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS): try: - config = DeviceConfig.from_dict(config_dict) + conn_params = Device.ConnectionParameters.from_dict(conn_params_dict) except KasaException: _LOGGER.warning( - "Invalid connection type dict for %s: %s", host, config_dict + "Invalid connection parameters dict for %s: %s", host, conn_params_dict ) - if not config: - config = DeviceConfig(host) - else: - config.host = host - - config.timeout = CONNECT_TIMEOUT - if config.uses_http is True: - config.http_client = create_async_tplink_clientsession(hass) - + client = create_async_tplink_clientsession(hass) if entry_use_http else None + config = DeviceConfig( + host, + timeout=CONNECT_TIMEOUT, + http_client=client, + aes_keys=entry_aes_keys, + ) + if conn_params: + config.connection_type = conn_params # If we have in memory credentials use them otherwise check for credentials_hash if credentials: config.credentials = credentials @@ -173,14 +178,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo raise ConfigEntryNotReady from ex device_credentials_hash = device.credentials_hash - device_config_dict = device.config.to_dict(exclude_credentials=True) - # Do not store the credentials hash inside the device_config - device_config_dict.pop(CONF_CREDENTIALS_HASH, None) + + # We not need to update the connection parameters or the use_http here + # because if they were wrong we would have failed to connect. + # Discovery will update those if necessary. updates: dict[str, Any] = {} if device_credentials_hash and device_credentials_hash != entry_credentials_hash: updates[CONF_CREDENTIALS_HASH] = device_credentials_hash - if device_config_dict != config_dict: - updates[CONF_DEVICE_CONFIG] = device_config_dict + if entry_aes_keys != device.config.aes_keys: + updates[CONF_AES_KEYS] = device.config.aes_keys if entry.data.get(CONF_ALIAS) != device.alias: updates[CONF_ALIAS] = device.alias if entry.data.get(CONF_MODEL) != device.model: @@ -307,12 +313,20 @@ def _device_id_is_mac_or_none(mac: str, device_ids: Iterable[str]) -> str | None async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Migrate old entry.""" - version = config_entry.version - minor_version = config_entry.minor_version + entry_version = config_entry.version + entry_minor_version = config_entry.minor_version + # having a condition to check for the current version allows + # tests to be written per migration step. + config_flow_minor_version = CONF_CONFIG_ENTRY_MINOR_VERSION - _LOGGER.debug("Migrating from version %s.%s", version, minor_version) - - if version == 1 and minor_version < 3: + new_minor_version = 3 + if ( + entry_version == 1 + and entry_minor_version < new_minor_version <= config_flow_minor_version + ): + _LOGGER.debug( + "Migrating from version %s.%s", entry_version, entry_minor_version + ) # Previously entities on child devices added themselves to the parent # device and set their device id as identifiers along with mac # as a connection which creates a single device entry linked by all @@ -359,12 +373,19 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> new_identifiers, ) - minor_version = 3 - hass.config_entries.async_update_entry(config_entry, minor_version=3) + hass.config_entries.async_update_entry( + config_entry, minor_version=new_minor_version + ) - _LOGGER.debug("Migration to version %s.%s complete", version, minor_version) + _LOGGER.debug( + "Migration to version %s.%s complete", entry_version, new_minor_version + ) - if version == 1 and minor_version == 3: + new_minor_version = 4 + if ( + entry_version == 1 + and entry_minor_version < new_minor_version <= config_flow_minor_version + ): # credentials_hash stored in the device_config should be moved to data. updates: dict[str, Any] = {} if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG): @@ -372,15 +393,44 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> if credentials_hash := config_dict.pop(CONF_CREDENTIALS_HASH, None): updates[CONF_CREDENTIALS_HASH] = credentials_hash updates[CONF_DEVICE_CONFIG] = config_dict - minor_version = 4 hass.config_entries.async_update_entry( config_entry, data={ **config_entry.data, **updates, }, - minor_version=minor_version, + minor_version=new_minor_version, + ) + _LOGGER.debug( + "Migration to version %s.%s complete", entry_version, new_minor_version ) - _LOGGER.debug("Migration to version %s.%s complete", version, minor_version) + new_minor_version = 5 + if ( + entry_version == 1 + and entry_minor_version < new_minor_version <= config_flow_minor_version + ): + # complete device config no longer to be stored, only required + # attributes like connection parameters and aes_keys + updates = {} + entry_data = { + k: v for k, v in config_entry.data.items() if k != CONF_DEVICE_CONFIG + } + if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG): + assert isinstance(config_dict, dict) + if connection_parameters := config_dict.get("connection_type"): + updates[CONF_CONNECTION_PARAMETERS] = connection_parameters + if (use_http := config_dict.get(CONF_USES_HTTP)) is not None: + updates[CONF_USES_HTTP] = use_http + hass.config_entries.async_update_entry( + config_entry, + data={ + **entry_data, + **updates, + }, + minor_version=new_minor_version, + ) + _LOGGER.debug( + "Migration to version %s.%s complete", entry_version, new_minor_version + ) return True diff --git a/homeassistant/components/tplink/config_flow.py b/homeassistant/components/tplink/config_flow.py index 1c02466aef1..03234d545b5 100644 --- a/homeassistant/components/tplink/config_flow.py +++ b/homeassistant/components/tplink/config_flow.py @@ -46,9 +46,11 @@ from . import ( set_credentials, ) from .const import ( - CONF_CONNECTION_TYPE, + CONF_AES_KEYS, + CONF_CONFIG_ENTRY_MINOR_VERSION, + CONF_CONNECTION_PARAMETERS, CONF_CREDENTIALS_HASH, - CONF_DEVICE_CONFIG, + CONF_USES_HTTP, CONNECT_TIMEOUT, DOMAIN, ) @@ -64,7 +66,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for tplink.""" VERSION = 1 - MINOR_VERSION = 4 + MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION reauth_entry: ConfigEntry | None = None def __init__(self) -> None: @@ -87,38 +89,43 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): return await self._async_handle_discovery( discovery_info[CONF_HOST], discovery_info[CONF_MAC], - discovery_info[CONF_DEVICE_CONFIG], + discovery_info[CONF_DEVICE], ) @callback def _get_config_updates( - self, entry: ConfigEntry, host: str, config: dict + self, entry: ConfigEntry, host: str, device: Device | None ) -> dict | None: """Return updates if the host or device config has changed.""" entry_data = entry.data - entry_config_dict = entry_data.get(CONF_DEVICE_CONFIG) - if entry_config_dict == config and entry_data[CONF_HOST] == host: + updates: dict[str, Any] = {} + new_connection_params = False + if entry_data[CONF_HOST] != host: + updates[CONF_HOST] = host + if device: + device_conn_params_dict = device.config.connection_type.to_dict() + entry_conn_params_dict = entry_data.get(CONF_CONNECTION_PARAMETERS) + if device_conn_params_dict != entry_conn_params_dict: + new_connection_params = True + updates[CONF_CONNECTION_PARAMETERS] = device_conn_params_dict + updates[CONF_USES_HTTP] = device.config.uses_http + if not updates: return None - updates = {**entry.data, CONF_DEVICE_CONFIG: config, CONF_HOST: host} + updates = {**entry.data, **updates} # If the connection parameters have changed the credentials_hash will be invalid. - if ( - entry_config_dict - and isinstance(entry_config_dict, dict) - and entry_config_dict.get(CONF_CONNECTION_TYPE) - != config.get(CONF_CONNECTION_TYPE) - ): + if new_connection_params: updates.pop(CONF_CREDENTIALS_HASH, None) _LOGGER.debug( "Connection type changed for %s from %s to: %s", host, - entry_config_dict.get(CONF_CONNECTION_TYPE), - config.get(CONF_CONNECTION_TYPE), + entry_conn_params_dict, + device_conn_params_dict, ) return updates @callback def _update_config_if_entry_in_setup_error( - self, entry: ConfigEntry, host: str, config: dict + self, entry: ConfigEntry, host: str, device: Device | None ) -> ConfigFlowResult | None: """If discovery encounters a device that is in SETUP_ERROR or SETUP_RETRY update the device config.""" if entry.state not in ( @@ -126,7 +133,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): ConfigEntryState.SETUP_RETRY, ): return None - if updates := self._get_config_updates(entry, host, config): + if updates := self._get_config_updates(entry, host, device): return self.async_update_reload_and_abort( entry, data=updates, @@ -135,19 +142,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): return None async def _async_handle_discovery( - self, host: str, formatted_mac: str, config: dict | None = None + self, host: str, formatted_mac: str, device: Device | None = None ) -> ConfigFlowResult: """Handle any discovery.""" current_entry = await self.async_set_unique_id( formatted_mac, raise_on_progress=False ) - if ( - config - and current_entry - and ( - result := self._update_config_if_entry_in_setup_error( - current_entry, host, config - ) + if current_entry and ( + result := self._update_config_if_entry_in_setup_error( + current_entry, host, device ) ): return result @@ -159,9 +162,13 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_abort(reason="already_in_progress") credentials = await get_credentials(self.hass) try: - await self._async_try_discover_and_update( - host, credentials, raise_on_progress=True - ) + if device: + self._discovered_device = device + await self._async_try_connect(device, credentials) + else: + await self._async_try_discover_and_update( + host, credentials, raise_on_progress=True + ) except AuthenticationError: return await self.async_step_discovery_auth_confirm() except KasaException: @@ -381,14 +388,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): # This is only ever called after a successful device update so we know that # the credential_hash is correct and should be saved. self._abort_if_unique_id_configured(updates={CONF_HOST: device.host}) - data = { + data: dict[str, Any] = { CONF_HOST: device.host, CONF_ALIAS: device.alias, CONF_MODEL: device.model, - CONF_DEVICE_CONFIG: device.config.to_dict( - exclude_credentials=True, - ), + CONF_CONNECTION_PARAMETERS: device.config.connection_type.to_dict(), + CONF_USES_HTTP: device.config.uses_http, } + if device.config.aes_keys: + data[CONF_AES_KEYS] = device.config.aes_keys if device.credentials_hash: data[CONF_CREDENTIALS_HASH] = device.credentials_hash return self.async_create_entry( @@ -494,8 +502,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): placeholders["error"] = str(ex) else: await set_credentials(self.hass, username, password) - config = device.config.to_dict(exclude_credentials=True) - if updates := self._get_config_updates(reauth_entry, host, config): + if updates := self._get_config_updates(reauth_entry, host, device): self.hass.config_entries.async_update_entry( reauth_entry, data=updates ) diff --git a/homeassistant/components/tplink/const.py b/homeassistant/components/tplink/const.py index babd92e2c34..91085edb5a2 100644 --- a/homeassistant/components/tplink/const.py +++ b/homeassistant/components/tplink/const.py @@ -21,7 +21,11 @@ ATTR_TOTAL_ENERGY_KWH: Final = "total_energy_kwh" CONF_DEVICE_CONFIG: Final = "device_config" CONF_CREDENTIALS_HASH: Final = "credentials_hash" -CONF_CONNECTION_TYPE: Final = "connection_type" +CONF_CONNECTION_PARAMETERS: Final = "connection_parameters" +CONF_USES_HTTP: Final = "uses_http" +CONF_AES_KEYS: Final = "aes_keys" + +CONF_CONFIG_ENTRY_MINOR_VERSION: Final = 5 PLATFORMS: Final = [ Platform.BINARY_SENSOR, diff --git a/tests/components/tplink/__init__.py b/tests/components/tplink/__init__.py index c63ca9139f1..93c3a35a2e9 100644 --- a/tests/components/tplink/__init__.py +++ b/tests/components/tplink/__init__.py @@ -21,11 +21,13 @@ from kasa.protocol import BaseProtocol from syrupy import SnapshotAssertion from homeassistant.components.tplink import ( + CONF_AES_KEYS, CONF_ALIAS, + CONF_CONNECTION_PARAMETERS, CONF_CREDENTIALS_HASH, - CONF_DEVICE_CONFIG, CONF_HOST, CONF_MODEL, + CONF_USES_HTTP, Credentials, ) from homeassistant.components.tplink.const import DOMAIN @@ -54,35 +56,42 @@ DHCP_FORMATTED_MAC_ADDRESS = MAC_ADDRESS.replace(":", "") MAC_ADDRESS2 = "11:22:33:44:55:66" DEFAULT_ENTRY_TITLE = f"{ALIAS} {MODEL}" CREDENTIALS_HASH_LEGACY = "" +CONN_PARAMS_LEGACY = DeviceConnectionParameters( + DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor +) DEVICE_CONFIG_LEGACY = DeviceConfig(IP_ADDRESS) DEVICE_CONFIG_DICT_LEGACY = DEVICE_CONFIG_LEGACY.to_dict(exclude_credentials=True) CREDENTIALS = Credentials("foo", "bar") CREDENTIALS_HASH_AES = "AES/abcdefghijklmnopqrstuvabcdefghijklmnopqrstuv==" CREDENTIALS_HASH_KLAP = "KLAP/abcdefghijklmnopqrstuv==" +CONN_PARAMS_KLAP = DeviceConnectionParameters( + DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap +) DEVICE_CONFIG_KLAP = DeviceConfig( IP_ADDRESS, credentials=CREDENTIALS, - connection_type=DeviceConnectionParameters( - DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap - ), + connection_type=CONN_PARAMS_KLAP, uses_http=True, ) +CONN_PARAMS_AES = DeviceConnectionParameters( + DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Aes +) +AES_KEYS = {"private": "foo", "public": "bar"} DEVICE_CONFIG_AES = DeviceConfig( IP_ADDRESS2, credentials=CREDENTIALS, - connection_type=DeviceConnectionParameters( - DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Aes - ), + connection_type=CONN_PARAMS_AES, uses_http=True, + aes_keys=AES_KEYS, ) DEVICE_CONFIG_DICT_KLAP = DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True) DEVICE_CONFIG_DICT_AES = DEVICE_CONFIG_AES.to_dict(exclude_credentials=True) - CREATE_ENTRY_DATA_LEGACY = { CONF_HOST: IP_ADDRESS, CONF_ALIAS: ALIAS, CONF_MODEL: MODEL, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, + CONF_CONNECTION_PARAMETERS: CONN_PARAMS_LEGACY.to_dict(), + CONF_USES_HTTP: False, } CREATE_ENTRY_DATA_KLAP = { @@ -90,23 +99,18 @@ CREATE_ENTRY_DATA_KLAP = { CONF_ALIAS: ALIAS, CONF_MODEL: MODEL, CONF_CREDENTIALS_HASH: CREDENTIALS_HASH_KLAP, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, + CONF_CONNECTION_PARAMETERS: CONN_PARAMS_KLAP.to_dict(), + CONF_USES_HTTP: True, } CREATE_ENTRY_DATA_AES = { CONF_HOST: IP_ADDRESS2, CONF_ALIAS: ALIAS, CONF_MODEL: MODEL, CONF_CREDENTIALS_HASH: CREDENTIALS_HASH_AES, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_AES, + CONF_CONNECTION_PARAMETERS: CONN_PARAMS_AES.to_dict(), + CONF_USES_HTTP: True, + CONF_AES_KEYS: AES_KEYS, } -CONNECTION_TYPE_KLAP = DeviceConnectionParameters( - DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap -) -CONNECTION_TYPE_KLAP_DICT = CONNECTION_TYPE_KLAP.to_dict() -CONNECTION_TYPE_AES = DeviceConnectionParameters( - DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Aes -) -CONNECTION_TYPE_AES_DICT = CONNECTION_TYPE_AES.to_dict() def _load_feature_fixtures(): @@ -452,11 +456,11 @@ MODULE_TO_MOCK_GEN = { } -def _patch_discovery(device=None, no_device=False): +def _patch_discovery(device=None, no_device=False, ip_address=IP_ADDRESS): async def _discovery(*args, **kwargs): if no_device: return {} - return {IP_ADDRESS: _mocked_device()} + return {ip_address: device if device else _mocked_device()} return patch("homeassistant.components.tplink.Discover.discover", new=_discovery) diff --git a/tests/components/tplink/conftest.py b/tests/components/tplink/conftest.py index ee4530575ce..f1586ee4a0a 100644 --- a/tests/components/tplink/conftest.py +++ b/tests/components/tplink/conftest.py @@ -1,9 +1,9 @@ """tplink conftest.""" from collections.abc import Generator -import copy from unittest.mock import DEFAULT, AsyncMock, patch +from kasa import DeviceConfig import pytest from homeassistant.components.tplink import DOMAIN @@ -34,13 +34,13 @@ def mock_discovery(): discover_single=DEFAULT, ) as mock_discovery: device = _mocked_device( - device_config=copy.deepcopy(DEVICE_CONFIG_KLAP), + device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), credentials_hash=CREDENTIALS_HASH_KLAP, alias=None, ) devices = { "127.0.0.1": _mocked_device( - device_config=copy.deepcopy(DEVICE_CONFIG_KLAP), + device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), credentials_hash=CREDENTIALS_HASH_KLAP, alias=None, ) @@ -57,12 +57,12 @@ def mock_connect(): with patch("homeassistant.components.tplink.Device.connect") as mock_connect: devices = { IP_ADDRESS: _mocked_device( - device_config=DEVICE_CONFIG_KLAP, + device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()), credentials_hash=CREDENTIALS_HASH_KLAP, ip_address=IP_ADDRESS, ), IP_ADDRESS2: _mocked_device( - device_config=DEVICE_CONFIG_AES, + device_config=DeviceConfig.from_dict(DEVICE_CONFIG_AES.to_dict()), credentials_hash=CREDENTIALS_HASH_AES, mac=MAC_ADDRESS2, ip_address=IP_ADDRESS2, diff --git a/tests/components/tplink/test_config_flow.py b/tests/components/tplink/test_config_flow.py index f90eb985d38..7b24769c858 100644 --- a/tests/components/tplink/test_config_flow.py +++ b/tests/components/tplink/test_config_flow.py @@ -1,5 +1,6 @@ """Test the tplink config flow.""" +from contextlib import contextmanager import logging from unittest.mock import AsyncMock, patch @@ -17,7 +18,7 @@ from homeassistant.components.tplink import ( KasaException, ) from homeassistant.components.tplink.const import ( - CONF_CONNECTION_TYPE, + CONF_CONNECTION_PARAMETERS, CONF_CREDENTIALS_HASH, CONF_DEVICE_CONFIG, ) @@ -34,17 +35,21 @@ from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from . import ( + AES_KEYS, ALIAS, - CONNECTION_TYPE_KLAP_DICT, + CONN_PARAMS_AES, + CONN_PARAMS_KLAP, + CONN_PARAMS_LEGACY, CREATE_ENTRY_DATA_AES, CREATE_ENTRY_DATA_KLAP, CREATE_ENTRY_DATA_LEGACY, CREDENTIALS_HASH_AES, CREDENTIALS_HASH_KLAP, DEFAULT_ENTRY_TITLE, - DEVICE_CONFIG_DICT_AES, + DEVICE_CONFIG_AES, DEVICE_CONFIG_DICT_KLAP, - DEVICE_CONFIG_DICT_LEGACY, + DEVICE_CONFIG_KLAP, + DEVICE_CONFIG_LEGACY, DHCP_FORMATTED_MAC_ADDRESS, IP_ADDRESS, MAC_ADDRESS, @@ -59,9 +64,44 @@ from . import ( from tests.common import MockConfigEntry -async def test_discovery(hass: HomeAssistant) -> None: +@contextmanager +def override_side_effect(mock: AsyncMock, effect): + """Temporarily override a mock side effect and replace afterwards.""" + try: + default_side_effect = mock.side_effect + mock.side_effect = effect + yield mock + finally: + mock.side_effect = default_side_effect + + +@pytest.mark.parametrize( + ("device_config", "expected_entry_data", "credentials_hash"), + [ + pytest.param( + DEVICE_CONFIG_KLAP, CREATE_ENTRY_DATA_KLAP, CREDENTIALS_HASH_KLAP, id="KLAP" + ), + pytest.param( + DEVICE_CONFIG_AES, CREATE_ENTRY_DATA_AES, CREDENTIALS_HASH_AES, id="AES" + ), + pytest.param(DEVICE_CONFIG_LEGACY, CREATE_ENTRY_DATA_LEGACY, None, id="Legacy"), + ], +) +async def test_discovery( + hass: HomeAssistant, device_config, expected_entry_data, credentials_hash +) -> None: """Test setting up discovery.""" - with _patch_discovery(), _patch_single_discovery(), _patch_connect(): + ip_address = device_config.host + device = _mocked_device( + device_config=device_config, + credentials_hash=credentials_hash, + ip_address=ip_address, + ) + with ( + _patch_discovery(device, ip_address=ip_address), + _patch_single_discovery(device), + _patch_connect(device), + ): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -91,9 +131,9 @@ async def test_discovery(hass: HomeAssistant) -> None: assert not result2["errors"] with ( - _patch_discovery(), - _patch_single_discovery(), - _patch_connect(), + _patch_discovery(device, ip_address=ip_address), + _patch_single_discovery(device), + _patch_connect(device), patch(f"{MODULE}.async_setup", return_value=True) as mock_setup, patch(f"{MODULE}.async_setup_entry", return_value=True) as mock_setup_entry, ): @@ -105,7 +145,7 @@ async def test_discovery(hass: HomeAssistant) -> None: assert result3["type"] is FlowResultType.CREATE_ENTRY assert result3["title"] == DEFAULT_ENTRY_TITLE - assert result3["data"] == CREATE_ENTRY_DATA_LEGACY + assert result3["data"] == expected_entry_data mock_setup.assert_called_once() mock_setup_entry.assert_called_once() @@ -130,24 +170,25 @@ async def test_discovery_auth( ) -> None: """Test authenticated discovery.""" - mock_discovery["mock_device"].update.side_effect = AuthenticationError + mock_device = mock_connect["mock_devices"][IP_ADDRESS] + assert mock_device.config == DEVICE_CONFIG_KLAP - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: IP_ADDRESS, - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, - ) + with override_side_effect(mock_connect["connect"], AuthenticationError): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: IP_ADDRESS, + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: mock_device, + }, + ) await hass.async_block_till_done() assert result["type"] is FlowResultType.FORM assert result["step_id"] == "discovery_auth_confirm" assert not result["errors"] - mock_discovery["mock_device"].update.reset_mock(side_effect=True) result2 = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={ @@ -172,40 +213,43 @@ async def test_discovery_auth( ) async def test_discovery_auth_errors( hass: HomeAssistant, - mock_discovery: AsyncMock, mock_connect: AsyncMock, mock_init, error_type, errors_msg, error_placement, ) -> None: - """Test handling of discovery authentication errors.""" - mock_discovery["mock_device"].update.side_effect = AuthenticationError - default_connect_side_effect = mock_connect["connect"].side_effect - mock_connect["connect"].side_effect = error_type + """Test handling of discovery authentication errors. - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: IP_ADDRESS, - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, - ) - await hass.async_block_till_done() + Tests for errors received during credential + entry during discovery_auth_confirm. + """ + mock_device = mock_connect["mock_devices"][IP_ADDRESS] + + with override_side_effect(mock_connect["connect"], AuthenticationError): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: IP_ADDRESS, + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: mock_device, + }, + ) + await hass.async_block_till_done() assert result["type"] is FlowResultType.FORM assert result["step_id"] == "discovery_auth_confirm" assert not result["errors"] - result2 = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={ - CONF_USERNAME: "fake_username", - CONF_PASSWORD: "fake_password", - }, - ) + with override_side_effect(mock_connect["connect"], error_type): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) assert result2["type"] is FlowResultType.FORM assert result2["errors"] == {error_placement: errors_msg} @@ -213,7 +257,6 @@ async def test_discovery_auth_errors( await hass.async_block_till_done() - mock_connect["connect"].side_effect = default_connect_side_effect result3 = await hass.config_entries.flow.async_configure( result2["flow_id"], { @@ -228,29 +271,29 @@ async def test_discovery_auth_errors( async def test_discovery_new_credentials( hass: HomeAssistant, - mock_discovery: AsyncMock, mock_connect: AsyncMock, mock_init, ) -> None: """Test setting up discovery with new credentials.""" - mock_discovery["mock_device"].update.side_effect = AuthenticationError + mock_device = mock_connect["mock_devices"][IP_ADDRESS] - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: IP_ADDRESS, - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, - ) - await hass.async_block_till_done() + with override_side_effect(mock_connect["connect"], AuthenticationError): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: IP_ADDRESS, + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: mock_device, + }, + ) + await hass.async_block_till_done() assert result["type"] is FlowResultType.FORM assert result["step_id"] == "discovery_auth_confirm" assert not result["errors"] - assert mock_connect["connect"].call_count == 0 + assert mock_connect["connect"].call_count == 1 with patch( "homeassistant.components.tplink.config_flow.get_credentials", @@ -260,7 +303,7 @@ async def test_discovery_new_credentials( result["flow_id"], ) - assert mock_connect["connect"].call_count == 1 + assert mock_connect["connect"].call_count == 2 assert result2["type"] is FlowResultType.FORM assert result2["step_id"] == "discovery_confirm" @@ -277,48 +320,54 @@ async def test_discovery_new_credentials( async def test_discovery_new_credentials_invalid( hass: HomeAssistant, - mock_discovery: AsyncMock, mock_connect: AsyncMock, mock_init, ) -> None: """Test setting up discovery with new invalid credentials.""" - mock_discovery["mock_device"].update.side_effect = AuthenticationError - default_connect_side_effect = mock_connect["connect"].side_effect + mock_device = mock_connect["mock_devices"][IP_ADDRESS] - mock_connect["connect"].side_effect = AuthenticationError - - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: IP_ADDRESS, - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, - ) - await hass.async_block_till_done() + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + patch( + "homeassistant.components.tplink.config_flow.get_credentials", + return_value=None, + ), + override_side_effect(mock_connect["connect"], AuthenticationError), + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: IP_ADDRESS, + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: mock_device, + }, + ) + await hass.async_block_till_done() assert result["type"] is FlowResultType.FORM assert result["step_id"] == "discovery_auth_confirm" assert not result["errors"] - assert mock_connect["connect"].call_count == 0 + assert mock_connect["connect"].call_count == 1 - with patch( - "homeassistant.components.tplink.config_flow.get_credentials", - return_value=Credentials("fake_user", "fake_pass"), + with ( + patch( + "homeassistant.components.tplink.config_flow.get_credentials", + return_value=Credentials("fake_user", "fake_pass"), + ), + override_side_effect(mock_connect["connect"], AuthenticationError), ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], ) - assert mock_connect["connect"].call_count == 1 + assert mock_connect["connect"].call_count == 2 assert result2["type"] is FlowResultType.FORM assert result2["step_id"] == "discovery_auth_confirm" await hass.async_block_till_done() - mock_connect["connect"].side_effect = default_connect_side_effect result3 = await hass.config_entries.flow.async_configure( result2["flow_id"], { @@ -577,32 +626,30 @@ async def test_manual_auth_errors( assert not result["errors"] mock_discovery["mock_device"].update.side_effect = AuthenticationError - default_connect_side_effect = mock_connect["connect"].side_effect - mock_connect["connect"].side_effect = error_type - result2 = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={CONF_HOST: IP_ADDRESS} - ) + with override_side_effect(mock_connect["connect"], error_type): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_HOST: IP_ADDRESS} + ) assert result2["type"] is FlowResultType.FORM assert result2["step_id"] == "user_auth_confirm" assert not result2["errors"] await hass.async_block_till_done() - - result3 = await hass.config_entries.flow.async_configure( - result2["flow_id"], - user_input={ - CONF_USERNAME: "fake_username", - CONF_PASSWORD: "fake_password", - }, - ) - await hass.async_block_till_done() + with override_side_effect(mock_connect["connect"], error_type): + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], + user_input={ + CONF_USERNAME: "fake_username", + CONF_PASSWORD: "fake_password", + }, + ) + await hass.async_block_till_done() assert result3["type"] is FlowResultType.FORM assert result3["step_id"] == "user_auth_confirm" assert result3["errors"] == {error_placement: errors_msg} assert result3["description_placeholders"]["error"] == str(error_type) - mock_connect["connect"].side_effect = default_connect_side_effect result4 = await hass.config_entries.flow.async_configure( result3["flow_id"], { @@ -628,7 +675,7 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None: CONF_HOST: IP_ADDRESS, CONF_MAC: MAC_ADDRESS, CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, + CONF_DEVICE: _mocked_device(device_config=DEVICE_CONFIG_LEGACY), }, ) await hass.async_block_till_done() @@ -691,7 +738,7 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None: CONF_HOST: IP_ADDRESS, CONF_MAC: MAC_ADDRESS, CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, + CONF_DEVICE: _mocked_device(device_config=DEVICE_CONFIG_LEGACY), }, ), ], @@ -745,7 +792,7 @@ async def test_discovered_by_dhcp_or_discovery( CONF_HOST: IP_ADDRESS, CONF_MAC: MAC_ADDRESS, CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, + CONF_DEVICE: _mocked_device(device_config=DEVICE_CONFIG_LEGACY), }, ), ], @@ -775,9 +822,11 @@ async def test_integration_discovery_with_ip_change( mock_connect: AsyncMock, ) -> None: """Test reauth flow.""" - mock_connect["connect"].side_effect = KasaException() mock_config_entry.add_to_hass(hass) - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], KasaException()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() @@ -785,39 +834,57 @@ async def test_integration_discovery_with_ip_change( flows = hass.config_entries.flow.async_progress() assert len(flows) == 0 - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY - assert mock_config_entry.data[CONF_DEVICE_CONFIG].get(CONF_HOST) == "127.0.0.1" - - discovery_result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: "127.0.0.2", - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] + == CONN_PARAMS_LEGACY.to_dict() ) + assert mock_config_entry.data[CONF_HOST] == "127.0.0.1" + + mocked_device = _mocked_device(device_config=DEVICE_CONFIG_KLAP) + with override_side_effect(mock_connect["connect"], lambda *_, **__: mocked_device): + discovery_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: "127.0.0.2", + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: mocked_device, + }, + ) await hass.async_block_till_done() assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["reason"] == "already_configured" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" config = DeviceConfig.from_dict(DEVICE_CONFIG_DICT_KLAP) + # Do a reload here and check that the + # new config is picked up in setup_entry mock_connect["connect"].reset_mock(side_effect=True) bulb = _mocked_device( device_config=config, mac=mock_config_entry.unique_id, ) - mock_connect["connect"].return_value = bulb - await hass.config_entries.async_reload(mock_config_entry.entry_id) - await hass.async_block_till_done() + + with ( + patch( + "homeassistant.components.tplink.async_create_clientsession", + return_value="Foo", + ), + override_side_effect(mock_connect["connect"], lambda *_, **__: bulb), + ): + await hass.config_entries.async_reload(mock_config_entry.entry_id) + await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.LOADED # Check that init set the new host correctly before calling connect assert config.host == "127.0.0.1" config.host = "127.0.0.2" + config.uses_http = False # Not passed in to new config class + config.http_client = "Foo" mock_connect["connect"].assert_awaited_once_with(config=config) @@ -831,8 +898,6 @@ async def test_integration_discovery_with_connection_change( And that connection_hash is removed as it will be invalid. """ - mock_connect["connect"].side_effect = KasaException() - mock_config_entry = MockConfigEntry( title="TPLink", domain=DOMAIN, @@ -840,7 +905,10 @@ async def test_integration_discovery_with_connection_change( unique_id=MAC_ADDRESS2, ) mock_config_entry.add_to_hass(hass) - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], KasaException()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done(wait_background_tasks=True) @@ -854,43 +922,57 @@ async def test_integration_discovery_with_connection_change( == 0 ) assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_AES - assert mock_config_entry.data[CONF_DEVICE_CONFIG].get(CONF_HOST) == "127.0.0.2" + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_AES.to_dict() + ) assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_AES + mock_connect["connect"].reset_mock() NEW_DEVICE_CONFIG = { **DEVICE_CONFIG_DICT_KLAP, - CONF_CONNECTION_TYPE: CONNECTION_TYPE_KLAP_DICT, + "connection_type": CONN_PARAMS_KLAP.to_dict(), CONF_HOST: "127.0.0.2", } config = DeviceConfig.from_dict(NEW_DEVICE_CONFIG) # Reset the connect mock so when the config flow reloads the entry it succeeds - mock_connect["connect"].reset_mock(side_effect=True) + bulb = _mocked_device( device_config=config, mac=mock_config_entry.unique_id, ) - mock_connect["connect"].return_value = bulb - discovery_result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: "127.0.0.2", - CONF_MAC: MAC_ADDRESS2, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: NEW_DEVICE_CONFIG, - }, - ) + with ( + patch( + "homeassistant.components.tplink.async_create_clientsession", + return_value="Foo", + ), + override_side_effect(mock_connect["connect"], lambda *_, **__: bulb), + ): + discovery_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: "127.0.0.2", + CONF_MAC: MAC_ADDRESS2, + CONF_ALIAS: ALIAS, + CONF_DEVICE: bulb, + }, + ) await hass.async_block_till_done(wait_background_tasks=True) assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["reason"] == "already_configured" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == NEW_DEVICE_CONFIG + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert CREDENTIALS_HASH_AES not in mock_config_entry.data assert mock_config_entry.state is ConfigEntryState.LOADED + config.host = "127.0.0.2" + config.uses_http = False # Not passed in to new config class + config.http_client = "Foo" + config.aes_keys = AES_KEYS mock_connect["connect"].assert_awaited_once_with(config=config) @@ -901,17 +983,18 @@ async def test_dhcp_discovery_with_ip_change( mock_connect: AsyncMock, ) -> None: """Test dhcp discovery with an IP change.""" - mock_connect["connect"].side_effect = KasaException() mock_config_entry.add_to_hass(hass) - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], KasaException()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY flows = hass.config_entries.flow.async_progress() assert len(flows) == 0 - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY - assert mock_config_entry.data[CONF_DEVICE_CONFIG].get(CONF_HOST) == "127.0.0.1" + assert mock_config_entry.data[CONF_HOST] == "127.0.0.1" discovery_result = await hass.config_entries.flow.async_init( DOMAIN, @@ -966,8 +1049,7 @@ async def test_reauth_update_with_encryption_change( caplog: pytest.LogCaptureFixture, ) -> None: """Test reauth flow.""" - orig_side_effect = mock_connect["connect"].side_effect - mock_connect["connect"].side_effect = AuthenticationError() + mock_config_entry = MockConfigEntry( title="TPLink", domain=DOMAIN, @@ -975,10 +1057,15 @@ async def test_reauth_update_with_encryption_change( unique_id=MAC_ADDRESS2, ) mock_config_entry.add_to_hass(hass) - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_AES + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_AES.to_dict() + ) assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_AES - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], AuthenticationError()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR @@ -988,7 +1075,9 @@ async def test_reauth_update_with_encryption_change( assert len(flows) == 1 [result] = flows assert result["step_id"] == "reauth_confirm" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_AES + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_AES.to_dict() + ) assert CONF_CREDENTIALS_HASH not in mock_config_entry.data new_config = DeviceConfig( @@ -1005,7 +1094,6 @@ async def test_reauth_update_with_encryption_change( mock_connect["mock_devices"]["127.0.0.2"].config = new_config mock_connect["mock_devices"]["127.0.0.2"].credentials_hash = CREDENTIALS_HASH_KLAP - mock_connect["connect"].side_effect = orig_side_effect result2 = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={ @@ -1023,10 +1111,10 @@ async def test_reauth_update_with_encryption_change( assert result2["type"] is FlowResultType.ABORT assert result2["reason"] == "reauth_successful" assert mock_config_entry.state is ConfigEntryState.LOADED - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == { - **DEVICE_CONFIG_DICT_KLAP, - CONF_HOST: "127.0.0.2", - } + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) + assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_KLAP @@ -1037,9 +1125,11 @@ async def test_reauth_update_from_discovery( mock_connect: AsyncMock, ) -> None: """Test reauth flow.""" - mock_connect["connect"].side_effect = AuthenticationError mock_config_entry.add_to_hass(hass) - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], AuthenticationError()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() @@ -1049,22 +1139,32 @@ async def test_reauth_update_from_discovery( assert len(flows) == 1 [result] = flows assert result["step_id"] == "reauth_confirm" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY - - discovery_result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: IP_ADDRESS, - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] + == CONN_PARAMS_LEGACY.to_dict() ) + + device = _mocked_device( + device_config=DEVICE_CONFIG_KLAP, + mac=mock_config_entry.unique_id, + ) + with override_side_effect(mock_connect["connect"], lambda *_, **__: device): + discovery_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: IP_ADDRESS, + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: device, + }, + ) await hass.async_block_till_done() assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["reason"] == "already_configured" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) async def test_reauth_update_from_discovery_with_ip_change( @@ -1074,9 +1174,11 @@ async def test_reauth_update_from_discovery_with_ip_change( mock_connect: AsyncMock, ) -> None: """Test reauth flow.""" - mock_connect["connect"].side_effect = AuthenticationError() mock_config_entry.add_to_hass(hass) - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], AuthenticationError()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR @@ -1085,22 +1187,32 @@ async def test_reauth_update_from_discovery_with_ip_change( assert len(flows) == 1 [result] = flows assert result["step_id"] == "reauth_confirm" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY - - discovery_result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: "127.0.0.2", - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] + == CONN_PARAMS_LEGACY.to_dict() ) + + device = _mocked_device( + device_config=DEVICE_CONFIG_KLAP, + mac=mock_config_entry.unique_id, + ) + with override_side_effect(mock_connect["connect"], lambda *_, **__: device): + discovery_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: "127.0.0.2", + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: device, + }, + ) await hass.async_block_till_done() assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["reason"] == "already_configured" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" @@ -1111,8 +1223,8 @@ async def test_reauth_no_update_if_config_and_ip_the_same( mock_connect: AsyncMock, ) -> None: """Test reauth discovery does not update when the host and config are the same.""" - mock_connect["connect"].side_effect = AuthenticationError() mock_config_entry.add_to_hass(hass) + hass.config_entries.async_update_entry( mock_config_entry, data={ @@ -1120,30 +1232,40 @@ async def test_reauth_no_update_if_config_and_ip_the_same( CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, }, ) - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() + with override_side_effect(mock_connect["connect"], AuthenticationError()): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR flows = hass.config_entries.flow.async_progress() assert len(flows) == 1 [result] = flows assert result["step_id"] == "reauth_confirm" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP - - discovery_result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, - data={ - CONF_HOST: IP_ADDRESS, - CONF_MAC: MAC_ADDRESS, - CONF_ALIAS: ALIAS, - CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, - }, + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() ) + + device = _mocked_device( + device_config=DEVICE_CONFIG_KLAP, + mac=mock_config_entry.unique_id, + ) + with override_side_effect(mock_connect["connect"], lambda *_, **__: device): + discovery_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, + data={ + CONF_HOST: IP_ADDRESS, + CONF_MAC: MAC_ADDRESS, + CONF_ALIAS: ALIAS, + CONF_DEVICE: device, + }, + ) await hass.async_block_till_done() assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["reason"] == "already_configured" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) assert mock_config_entry.data[CONF_HOST] == IP_ADDRESS @@ -1241,17 +1363,15 @@ async def test_pick_device_errors( assert result2["step_id"] == "pick_device" assert not result2["errors"] - default_connect_side_effect = mock_connect["connect"].side_effect - mock_connect["connect"].side_effect = error_type - result3 = await hass.config_entries.flow.async_configure( - result2["flow_id"], - {CONF_DEVICE: MAC_ADDRESS}, - ) - await hass.async_block_till_done() + with override_side_effect(mock_connect["connect"], error_type): + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], + {CONF_DEVICE: MAC_ADDRESS}, + ) + await hass.async_block_till_done() assert result3["type"] == expected_flow if expected_flow != FlowResultType.ABORT: - mock_connect["connect"].side_effect = default_connect_side_effect result4 = await hass.config_entries.flow.async_configure( result3["flow_id"], user_input={ @@ -1300,17 +1420,17 @@ async def test_discovery_timeout_connect_legacy_error( DOMAIN, context={"source": config_entries.SOURCE_USER} ) mock_discovery["discover_single"].side_effect = TimeoutError - mock_connect["connect"].side_effect = KasaException await hass.async_block_till_done() assert result["type"] is FlowResultType.FORM assert result["step_id"] == "user" assert not result["errors"] assert mock_connect["connect"].call_count == 0 - result2 = await hass.config_entries.flow.async_configure( - result["flow_id"], {CONF_HOST: IP_ADDRESS} - ) - await hass.async_block_till_done() + with override_side_effect(mock_connect["connect"], KasaException): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_HOST: IP_ADDRESS} + ) + await hass.async_block_till_done() assert result2["type"] is FlowResultType.FORM assert result2["errors"] == {"base": "cannot_connect"} assert mock_connect["connect"].call_count == 1 @@ -1334,17 +1454,17 @@ async def test_reauth_update_other_flows( data={**CREATE_ENTRY_DATA_AES}, unique_id=MAC_ADDRESS2, ) - default_side_effect = mock_connect["connect"].side_effect - mock_connect["connect"].side_effect = AuthenticationError() mock_config_entry.add_to_hass(hass) mock_config_entry2.add_to_hass(hass) - with patch("homeassistant.components.tplink.Discover.discover", return_value={}): + with ( + patch("homeassistant.components.tplink.Discover.discover", return_value={}), + override_side_effect(mock_connect["connect"], AuthenticationError()), + ): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() assert mock_config_entry2.state is ConfigEntryState.SETUP_ERROR assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR - mock_connect["connect"].side_effect = default_side_effect await hass.async_block_till_done() @@ -1353,7 +1473,9 @@ async def test_reauth_update_other_flows( flows_by_entry_id = {flow["context"]["entry_id"]: flow for flow in flows} result = flows_by_entry_id[mock_config_entry.entry_id] assert result["step_id"] == "reauth_confirm" - assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP + assert ( + mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict() + ) result2 = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={ diff --git a/tests/components/tplink/test_init.py b/tests/components/tplink/test_init.py index 986aaebd170..dd01c381adf 100644 --- a/tests/components/tplink/test_init.py +++ b/tests/components/tplink/test_init.py @@ -4,6 +4,7 @@ from __future__ import annotations import copy from datetime import timedelta +from typing import Any from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch from freezegun.api import FrozenDateTimeFactory @@ -13,14 +14,18 @@ import pytest from homeassistant import setup from homeassistant.components import tplink from homeassistant.components.tplink.const import ( + CONF_AES_KEYS, + CONF_CONNECTION_PARAMETERS, CONF_CREDENTIALS_HASH, CONF_DEVICE_CONFIG, DOMAIN, ) from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState from homeassistant.const import ( + CONF_ALIAS, CONF_AUTHENTICATION, CONF_HOST, + CONF_MODEL, CONF_PASSWORD, CONF_USERNAME, STATE_ON, @@ -33,13 +38,20 @@ from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util from . import ( + ALIAS, + CREATE_ENTRY_DATA_AES, CREATE_ENTRY_DATA_KLAP, CREATE_ENTRY_DATA_LEGACY, + CREDENTIALS_HASH_AES, + CREDENTIALS_HASH_KLAP, + DEVICE_CONFIG_AES, DEVICE_CONFIG_KLAP, + DEVICE_CONFIG_LEGACY, DEVICE_ID, DEVICE_ID_MAC, IP_ADDRESS, MAC_ADDRESS, + MODEL, _mocked_device, _patch_connect, _patch_discovery, @@ -207,16 +219,21 @@ async def test_config_entry_with_stored_credentials( hass.data.setdefault(DOMAIN, {})[CONF_AUTHENTICATION] = auth mock_config_entry.add_to_hass(hass) - await hass.config_entries.async_setup(mock_config_entry.entry_id) + with patch( + "homeassistant.components.tplink.async_create_clientsession", return_value="Foo" + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.LOADED - config = DEVICE_CONFIG_KLAP + config = DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()) + config.uses_http = False + config.http_client = "Foo" assert config.credentials != stored_credentials config.credentials = stored_credentials mock_connect["connect"].assert_called_once_with(config=config) -async def test_config_entry_device_config_invalid( +async def test_config_entry_conn_params_invalid( hass: HomeAssistant, mock_discovery: AsyncMock, mock_connect: AsyncMock, @@ -224,7 +241,7 @@ async def test_config_entry_device_config_invalid( ) -> None: """Test that an invalid device config logs an error and loads the config entry.""" entry_data = copy.deepcopy(CREATE_ENTRY_DATA_KLAP) - entry_data[CONF_DEVICE_CONFIG] = {"foo": "bar"} + entry_data[CONF_CONNECTION_PARAMETERS] = {"foo": "bar"} mock_config_entry = MockConfigEntry( title="TPLink", domain=DOMAIN, @@ -237,7 +254,7 @@ async def test_config_entry_device_config_invalid( assert mock_config_entry.state is ConfigEntryState.LOADED assert ( - f"Invalid connection type dict for {IP_ADDRESS}: {entry_data.get(CONF_DEVICE_CONFIG)}" + f"Invalid connection parameters dict for {IP_ADDRESS}: {entry_data.get(CONF_CONNECTION_PARAMETERS)}" in caplog.text ) @@ -495,8 +512,9 @@ async def test_unlink_devices( } assert device_entries[0].identifiers == set(test_identifiers) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() + with patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 3): + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id) @@ -504,7 +522,7 @@ async def test_unlink_devices( assert device_entries[0].identifiers == set(expected_identifiers) assert entry.version == 1 - assert entry.minor_version == 4 + assert entry.minor_version == 3 assert update_msg in caplog.text assert "Migration to version 1.3 complete" in caplog.text @@ -545,6 +563,7 @@ async def test_move_credentials_hash( with ( patch("homeassistant.components.tplink.Device.connect", new=_connect), patch("homeassistant.components.tplink.PLATFORMS", []), + patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 4), ): await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() @@ -589,6 +608,7 @@ async def test_move_credentials_hash_auth_error( side_effect=AuthenticationError, ), patch("homeassistant.components.tplink.PLATFORMS", []), + patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 4), ): entry.add_to_hass(hass) await hass.config_entries.async_setup(entry.entry_id) @@ -631,6 +651,7 @@ async def test_move_credentials_hash_other_error( "homeassistant.components.tplink.Device.connect", side_effect=KasaException ), patch("homeassistant.components.tplink.PLATFORMS", []), + patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 4), ): entry.add_to_hass(hass) await hass.config_entries.async_setup(entry.entry_id) @@ -647,10 +668,8 @@ async def test_credentials_hash( hass: HomeAssistant, ) -> None: """Test credentials_hash used to call connect.""" - device_config = {**DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True)} entry_data = { **CREATE_ENTRY_DATA_KLAP, - CONF_DEVICE_CONFIG: device_config, CONF_CREDENTIALS_HASH: "theHash", } @@ -674,9 +693,7 @@ async def test_credentials_hash( await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED - assert CONF_CREDENTIALS_HASH not in entry.data[CONF_DEVICE_CONFIG] assert CONF_CREDENTIALS_HASH in entry.data - assert entry.data[CONF_DEVICE_CONFIG] == device_config assert entry.data[CONF_CREDENTIALS_HASH] == "theHash" @@ -684,10 +701,8 @@ async def test_credentials_hash_auth_error( hass: HomeAssistant, ) -> None: """Test credentials_hash is deleted after an auth failure.""" - device_config = {**DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True)} entry_data = { **CREATE_ENTRY_DATA_KLAP, - CONF_DEVICE_CONFIG: device_config, CONF_CREDENTIALS_HASH: "theHash", } @@ -700,6 +715,10 @@ async def test_credentials_hash_auth_error( with ( patch("homeassistant.components.tplink.PLATFORMS", []), + patch( + "homeassistant.components.tplink.async_create_clientsession", + return_value="Foo", + ), patch( "homeassistant.components.tplink.Device.connect", side_effect=AuthenticationError, @@ -712,6 +731,76 @@ async def test_credentials_hash_auth_error( expected_config = DeviceConfig.from_dict( DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True, credentials_hash="theHash") ) + expected_config.uses_http = False + expected_config.http_client = "Foo" connect_mock.assert_called_with(config=expected_config) assert entry.state is ConfigEntryState.SETUP_ERROR assert CONF_CREDENTIALS_HASH not in entry.data + + +@pytest.mark.parametrize( + ("device_config", "expected_entry_data", "credentials_hash"), + [ + pytest.param( + DEVICE_CONFIG_KLAP, CREATE_ENTRY_DATA_KLAP, CREDENTIALS_HASH_KLAP, id="KLAP" + ), + pytest.param( + DEVICE_CONFIG_AES, CREATE_ENTRY_DATA_AES, CREDENTIALS_HASH_AES, id="AES" + ), + pytest.param(DEVICE_CONFIG_LEGACY, CREATE_ENTRY_DATA_LEGACY, None, id="Legacy"), + ], +) +async def test_migrate_remove_device_config( + hass: HomeAssistant, + mock_connect: AsyncMock, + caplog: pytest.LogCaptureFixture, + device_config: DeviceConfig, + expected_entry_data: dict[str, Any], + credentials_hash: str, +) -> None: + """Test credentials hash moved to parent. + + As async_setup_entry will succeed the hash on the parent is updated + from the device. + """ + OLD_CREATE_ENTRY_DATA = { + CONF_HOST: expected_entry_data[CONF_HOST], + CONF_ALIAS: ALIAS, + CONF_MODEL: MODEL, + CONF_DEVICE_CONFIG: device_config.to_dict(exclude_credentials=True), + } + + entry = MockConfigEntry( + title="TPLink", + domain=DOMAIN, + data=OLD_CREATE_ENTRY_DATA, + entry_id="123456", + unique_id=MAC_ADDRESS, + version=1, + minor_version=4, + ) + entry.add_to_hass(hass) + + async def _connect(config): + config.credentials_hash = credentials_hash + config.aes_keys = expected_entry_data.get(CONF_AES_KEYS) + return _mocked_device(device_config=config, credentials_hash=credentials_hash) + + with ( + patch("homeassistant.components.tplink.Device.connect", new=_connect), + patch("homeassistant.components.tplink.PLATFORMS", []), + patch( + "homeassistant.components.tplink.async_create_clientsession", + return_value="Foo", + ), + patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 5), + ): + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + assert entry.minor_version == 5 + assert entry.state is ConfigEntryState.LOADED + assert CONF_DEVICE_CONFIG not in entry.data + assert entry.data == expected_entry_data + + assert "Migration to version 1.5 complete" in caplog.text