Update tplink config to include aes keys (#125685)

This commit is contained in:
Steven B. 2024-09-10 19:52:10 +01:00 committed by GitHub
parent 44ca43c7ee
commit 40ee39f258
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 598 additions and 322 deletions

View File

@ -26,6 +26,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_ALIAS, CONF_ALIAS,
CONF_AUTHENTICATION, CONF_AUTHENTICATION,
CONF_DEVICE,
CONF_HOST, CONF_HOST,
CONF_MAC, CONF_MAC,
CONF_MODEL, CONF_MODEL,
@ -44,8 +45,12 @@ from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
CONF_AES_KEYS,
CONF_CONFIG_ENTRY_MINOR_VERSION,
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH, CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG, CONF_DEVICE_CONFIG,
CONF_USES_HTTP,
CONNECT_TIMEOUT, CONNECT_TIMEOUT,
DISCOVERY_TIMEOUT, DISCOVERY_TIMEOUT,
DOMAIN, DOMAIN,
@ -85,9 +90,7 @@ def async_trigger_discovery(
CONF_ALIAS: device.alias or mac_alias(device.mac), CONF_ALIAS: device.alias or mac_alias(device.mac),
CONF_HOST: device.host, CONF_HOST: device.host,
CONF_MAC: formatted_mac, CONF_MAC: formatted_mac,
CONF_DEVICE_CONFIG: device.config.to_dict( CONF_DEVICE: device,
exclude_credentials=True,
),
}, },
) )
@ -136,25 +139,27 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
host: str = entry.data[CONF_HOST] host: str = entry.data[CONF_HOST]
credentials = await get_credentials(hass) credentials = await get_credentials(hass)
entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH) entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH)
entry_use_http = entry.data.get(CONF_USES_HTTP, False)
entry_aes_keys = entry.data.get(CONF_AES_KEYS)
config: DeviceConfig | None = None conn_params: Device.ConnectionParameters | None = None
if config_dict := entry.data.get(CONF_DEVICE_CONFIG): if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS):
try: try:
config = DeviceConfig.from_dict(config_dict) conn_params = Device.ConnectionParameters.from_dict(conn_params_dict)
except KasaException: except KasaException:
_LOGGER.warning( _LOGGER.warning(
"Invalid connection type dict for %s: %s", host, config_dict "Invalid connection parameters dict for %s: %s", host, conn_params_dict
) )
if not config: client = create_async_tplink_clientsession(hass) if entry_use_http else None
config = DeviceConfig(host) config = DeviceConfig(
else: host,
config.host = host timeout=CONNECT_TIMEOUT,
http_client=client,
config.timeout = CONNECT_TIMEOUT aes_keys=entry_aes_keys,
if config.uses_http is True: )
config.http_client = create_async_tplink_clientsession(hass) if conn_params:
config.connection_type = conn_params
# If we have in memory credentials use them otherwise check for credentials_hash # If we have in memory credentials use them otherwise check for credentials_hash
if credentials: if credentials:
config.credentials = credentials config.credentials = credentials
@ -173,14 +178,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
raise ConfigEntryNotReady from ex raise ConfigEntryNotReady from ex
device_credentials_hash = device.credentials_hash device_credentials_hash = device.credentials_hash
device_config_dict = device.config.to_dict(exclude_credentials=True)
# Do not store the credentials hash inside the device_config # We not need to update the connection parameters or the use_http here
device_config_dict.pop(CONF_CREDENTIALS_HASH, None) # because if they were wrong we would have failed to connect.
# Discovery will update those if necessary.
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if device_credentials_hash and device_credentials_hash != entry_credentials_hash: if device_credentials_hash and device_credentials_hash != entry_credentials_hash:
updates[CONF_CREDENTIALS_HASH] = device_credentials_hash updates[CONF_CREDENTIALS_HASH] = device_credentials_hash
if device_config_dict != config_dict: if entry_aes_keys != device.config.aes_keys:
updates[CONF_DEVICE_CONFIG] = device_config_dict updates[CONF_AES_KEYS] = device.config.aes_keys
if entry.data.get(CONF_ALIAS) != device.alias: if entry.data.get(CONF_ALIAS) != device.alias:
updates[CONF_ALIAS] = device.alias updates[CONF_ALIAS] = device.alias
if entry.data.get(CONF_MODEL) != device.model: if entry.data.get(CONF_MODEL) != device.model:
@ -307,12 +313,20 @@ def _device_id_is_mac_or_none(mac: str, device_ids: Iterable[str]) -> str | None
async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Migrate old entry.""" """Migrate old entry."""
version = config_entry.version entry_version = config_entry.version
minor_version = config_entry.minor_version entry_minor_version = config_entry.minor_version
# having a condition to check for the current version allows
# tests to be written per migration step.
config_flow_minor_version = CONF_CONFIG_ENTRY_MINOR_VERSION
_LOGGER.debug("Migrating from version %s.%s", version, minor_version) new_minor_version = 3
if (
if version == 1 and minor_version < 3: entry_version == 1
and entry_minor_version < new_minor_version <= config_flow_minor_version
):
_LOGGER.debug(
"Migrating from version %s.%s", entry_version, entry_minor_version
)
# Previously entities on child devices added themselves to the parent # Previously entities on child devices added themselves to the parent
# device and set their device id as identifiers along with mac # device and set their device id as identifiers along with mac
# as a connection which creates a single device entry linked by all # as a connection which creates a single device entry linked by all
@ -359,12 +373,19 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
new_identifiers, new_identifiers,
) )
minor_version = 3 hass.config_entries.async_update_entry(
hass.config_entries.async_update_entry(config_entry, minor_version=3) config_entry, minor_version=new_minor_version
)
_LOGGER.debug("Migration to version %s.%s complete", version, minor_version) _LOGGER.debug(
"Migration to version %s.%s complete", entry_version, new_minor_version
)
if version == 1 and minor_version == 3: new_minor_version = 4
if (
entry_version == 1
and entry_minor_version < new_minor_version <= config_flow_minor_version
):
# credentials_hash stored in the device_config should be moved to data. # credentials_hash stored in the device_config should be moved to data.
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG): if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG):
@ -372,15 +393,44 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
if credentials_hash := config_dict.pop(CONF_CREDENTIALS_HASH, None): if credentials_hash := config_dict.pop(CONF_CREDENTIALS_HASH, None):
updates[CONF_CREDENTIALS_HASH] = credentials_hash updates[CONF_CREDENTIALS_HASH] = credentials_hash
updates[CONF_DEVICE_CONFIG] = config_dict updates[CONF_DEVICE_CONFIG] = config_dict
minor_version = 4
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
config_entry, config_entry,
data={ data={
**config_entry.data, **config_entry.data,
**updates, **updates,
}, },
minor_version=minor_version, minor_version=new_minor_version,
)
_LOGGER.debug(
"Migration to version %s.%s complete", entry_version, new_minor_version
) )
_LOGGER.debug("Migration to version %s.%s complete", version, minor_version)
new_minor_version = 5
if (
entry_version == 1
and entry_minor_version < new_minor_version <= config_flow_minor_version
):
# complete device config no longer to be stored, only required
# attributes like connection parameters and aes_keys
updates = {}
entry_data = {
k: v for k, v in config_entry.data.items() if k != CONF_DEVICE_CONFIG
}
if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG):
assert isinstance(config_dict, dict)
if connection_parameters := config_dict.get("connection_type"):
updates[CONF_CONNECTION_PARAMETERS] = connection_parameters
if (use_http := config_dict.get(CONF_USES_HTTP)) is not None:
updates[CONF_USES_HTTP] = use_http
hass.config_entries.async_update_entry(
config_entry,
data={
**entry_data,
**updates,
},
minor_version=new_minor_version,
)
_LOGGER.debug(
"Migration to version %s.%s complete", entry_version, new_minor_version
)
return True return True

View File

@ -46,9 +46,11 @@ from . import (
set_credentials, set_credentials,
) )
from .const import ( from .const import (
CONF_CONNECTION_TYPE, CONF_AES_KEYS,
CONF_CONFIG_ENTRY_MINOR_VERSION,
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH, CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG, CONF_USES_HTTP,
CONNECT_TIMEOUT, CONNECT_TIMEOUT,
DOMAIN, DOMAIN,
) )
@ -64,7 +66,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for tplink.""" """Handle a config flow for tplink."""
VERSION = 1 VERSION = 1
MINOR_VERSION = 4 MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
reauth_entry: ConfigEntry | None = None reauth_entry: ConfigEntry | None = None
def __init__(self) -> None: def __init__(self) -> None:
@ -87,38 +89,43 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return await self._async_handle_discovery( return await self._async_handle_discovery(
discovery_info[CONF_HOST], discovery_info[CONF_HOST],
discovery_info[CONF_MAC], discovery_info[CONF_MAC],
discovery_info[CONF_DEVICE_CONFIG], discovery_info[CONF_DEVICE],
) )
@callback @callback
def _get_config_updates( def _get_config_updates(
self, entry: ConfigEntry, host: str, config: dict self, entry: ConfigEntry, host: str, device: Device | None
) -> dict | None: ) -> dict | None:
"""Return updates if the host or device config has changed.""" """Return updates if the host or device config has changed."""
entry_data = entry.data entry_data = entry.data
entry_config_dict = entry_data.get(CONF_DEVICE_CONFIG) updates: dict[str, Any] = {}
if entry_config_dict == config and entry_data[CONF_HOST] == host: new_connection_params = False
if entry_data[CONF_HOST] != host:
updates[CONF_HOST] = host
if device:
device_conn_params_dict = device.config.connection_type.to_dict()
entry_conn_params_dict = entry_data.get(CONF_CONNECTION_PARAMETERS)
if device_conn_params_dict != entry_conn_params_dict:
new_connection_params = True
updates[CONF_CONNECTION_PARAMETERS] = device_conn_params_dict
updates[CONF_USES_HTTP] = device.config.uses_http
if not updates:
return None return None
updates = {**entry.data, CONF_DEVICE_CONFIG: config, CONF_HOST: host} updates = {**entry.data, **updates}
# If the connection parameters have changed the credentials_hash will be invalid. # If the connection parameters have changed the credentials_hash will be invalid.
if ( if new_connection_params:
entry_config_dict
and isinstance(entry_config_dict, dict)
and entry_config_dict.get(CONF_CONNECTION_TYPE)
!= config.get(CONF_CONNECTION_TYPE)
):
updates.pop(CONF_CREDENTIALS_HASH, None) updates.pop(CONF_CREDENTIALS_HASH, None)
_LOGGER.debug( _LOGGER.debug(
"Connection type changed for %s from %s to: %s", "Connection type changed for %s from %s to: %s",
host, host,
entry_config_dict.get(CONF_CONNECTION_TYPE), entry_conn_params_dict,
config.get(CONF_CONNECTION_TYPE), device_conn_params_dict,
) )
return updates return updates
@callback @callback
def _update_config_if_entry_in_setup_error( def _update_config_if_entry_in_setup_error(
self, entry: ConfigEntry, host: str, config: dict self, entry: ConfigEntry, host: str, device: Device | None
) -> ConfigFlowResult | None: ) -> ConfigFlowResult | None:
"""If discovery encounters a device that is in SETUP_ERROR or SETUP_RETRY update the device config.""" """If discovery encounters a device that is in SETUP_ERROR or SETUP_RETRY update the device config."""
if entry.state not in ( if entry.state not in (
@ -126,7 +133,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
ConfigEntryState.SETUP_RETRY, ConfigEntryState.SETUP_RETRY,
): ):
return None return None
if updates := self._get_config_updates(entry, host, config): if updates := self._get_config_updates(entry, host, device):
return self.async_update_reload_and_abort( return self.async_update_reload_and_abort(
entry, entry,
data=updates, data=updates,
@ -135,19 +142,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return None return None
async def _async_handle_discovery( async def _async_handle_discovery(
self, host: str, formatted_mac: str, config: dict | None = None self, host: str, formatted_mac: str, device: Device | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle any discovery.""" """Handle any discovery."""
current_entry = await self.async_set_unique_id( current_entry = await self.async_set_unique_id(
formatted_mac, raise_on_progress=False formatted_mac, raise_on_progress=False
) )
if ( if current_entry and (
config result := self._update_config_if_entry_in_setup_error(
and current_entry current_entry, host, device
and (
result := self._update_config_if_entry_in_setup_error(
current_entry, host, config
)
) )
): ):
return result return result
@ -159,9 +162,13 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="already_in_progress") return self.async_abort(reason="already_in_progress")
credentials = await get_credentials(self.hass) credentials = await get_credentials(self.hass)
try: try:
await self._async_try_discover_and_update( if device:
host, credentials, raise_on_progress=True self._discovered_device = device
) await self._async_try_connect(device, credentials)
else:
await self._async_try_discover_and_update(
host, credentials, raise_on_progress=True
)
except AuthenticationError: except AuthenticationError:
return await self.async_step_discovery_auth_confirm() return await self.async_step_discovery_auth_confirm()
except KasaException: except KasaException:
@ -381,14 +388,15 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
# This is only ever called after a successful device update so we know that # This is only ever called after a successful device update so we know that
# the credential_hash is correct and should be saved. # the credential_hash is correct and should be saved.
self._abort_if_unique_id_configured(updates={CONF_HOST: device.host}) self._abort_if_unique_id_configured(updates={CONF_HOST: device.host})
data = { data: dict[str, Any] = {
CONF_HOST: device.host, CONF_HOST: device.host,
CONF_ALIAS: device.alias, CONF_ALIAS: device.alias,
CONF_MODEL: device.model, CONF_MODEL: device.model,
CONF_DEVICE_CONFIG: device.config.to_dict( CONF_CONNECTION_PARAMETERS: device.config.connection_type.to_dict(),
exclude_credentials=True, CONF_USES_HTTP: device.config.uses_http,
),
} }
if device.config.aes_keys:
data[CONF_AES_KEYS] = device.config.aes_keys
if device.credentials_hash: if device.credentials_hash:
data[CONF_CREDENTIALS_HASH] = device.credentials_hash data[CONF_CREDENTIALS_HASH] = device.credentials_hash
return self.async_create_entry( return self.async_create_entry(
@ -494,8 +502,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
placeholders["error"] = str(ex) placeholders["error"] = str(ex)
else: else:
await set_credentials(self.hass, username, password) await set_credentials(self.hass, username, password)
config = device.config.to_dict(exclude_credentials=True) if updates := self._get_config_updates(reauth_entry, host, device):
if updates := self._get_config_updates(reauth_entry, host, config):
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(
reauth_entry, data=updates reauth_entry, data=updates
) )

View File

@ -21,7 +21,11 @@ ATTR_TOTAL_ENERGY_KWH: Final = "total_energy_kwh"
CONF_DEVICE_CONFIG: Final = "device_config" CONF_DEVICE_CONFIG: Final = "device_config"
CONF_CREDENTIALS_HASH: Final = "credentials_hash" CONF_CREDENTIALS_HASH: Final = "credentials_hash"
CONF_CONNECTION_TYPE: Final = "connection_type" CONF_CONNECTION_PARAMETERS: Final = "connection_parameters"
CONF_USES_HTTP: Final = "uses_http"
CONF_AES_KEYS: Final = "aes_keys"
CONF_CONFIG_ENTRY_MINOR_VERSION: Final = 5
PLATFORMS: Final = [ PLATFORMS: Final = [
Platform.BINARY_SENSOR, Platform.BINARY_SENSOR,

View File

@ -21,11 +21,13 @@ from kasa.protocol import BaseProtocol
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from homeassistant.components.tplink import ( from homeassistant.components.tplink import (
CONF_AES_KEYS,
CONF_ALIAS, CONF_ALIAS,
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH, CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG,
CONF_HOST, CONF_HOST,
CONF_MODEL, CONF_MODEL,
CONF_USES_HTTP,
Credentials, Credentials,
) )
from homeassistant.components.tplink.const import DOMAIN from homeassistant.components.tplink.const import DOMAIN
@ -54,35 +56,42 @@ DHCP_FORMATTED_MAC_ADDRESS = MAC_ADDRESS.replace(":", "")
MAC_ADDRESS2 = "11:22:33:44:55:66" MAC_ADDRESS2 = "11:22:33:44:55:66"
DEFAULT_ENTRY_TITLE = f"{ALIAS} {MODEL}" DEFAULT_ENTRY_TITLE = f"{ALIAS} {MODEL}"
CREDENTIALS_HASH_LEGACY = "" CREDENTIALS_HASH_LEGACY = ""
CONN_PARAMS_LEGACY = DeviceConnectionParameters(
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor
)
DEVICE_CONFIG_LEGACY = DeviceConfig(IP_ADDRESS) DEVICE_CONFIG_LEGACY = DeviceConfig(IP_ADDRESS)
DEVICE_CONFIG_DICT_LEGACY = DEVICE_CONFIG_LEGACY.to_dict(exclude_credentials=True) DEVICE_CONFIG_DICT_LEGACY = DEVICE_CONFIG_LEGACY.to_dict(exclude_credentials=True)
CREDENTIALS = Credentials("foo", "bar") CREDENTIALS = Credentials("foo", "bar")
CREDENTIALS_HASH_AES = "AES/abcdefghijklmnopqrstuvabcdefghijklmnopqrstuv==" CREDENTIALS_HASH_AES = "AES/abcdefghijklmnopqrstuvabcdefghijklmnopqrstuv=="
CREDENTIALS_HASH_KLAP = "KLAP/abcdefghijklmnopqrstuv==" CREDENTIALS_HASH_KLAP = "KLAP/abcdefghijklmnopqrstuv=="
CONN_PARAMS_KLAP = DeviceConnectionParameters(
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap
)
DEVICE_CONFIG_KLAP = DeviceConfig( DEVICE_CONFIG_KLAP = DeviceConfig(
IP_ADDRESS, IP_ADDRESS,
credentials=CREDENTIALS, credentials=CREDENTIALS,
connection_type=DeviceConnectionParameters( connection_type=CONN_PARAMS_KLAP,
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap
),
uses_http=True, uses_http=True,
) )
CONN_PARAMS_AES = DeviceConnectionParameters(
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Aes
)
AES_KEYS = {"private": "foo", "public": "bar"}
DEVICE_CONFIG_AES = DeviceConfig( DEVICE_CONFIG_AES = DeviceConfig(
IP_ADDRESS2, IP_ADDRESS2,
credentials=CREDENTIALS, credentials=CREDENTIALS,
connection_type=DeviceConnectionParameters( connection_type=CONN_PARAMS_AES,
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Aes
),
uses_http=True, uses_http=True,
aes_keys=AES_KEYS,
) )
DEVICE_CONFIG_DICT_KLAP = DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True) DEVICE_CONFIG_DICT_KLAP = DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True)
DEVICE_CONFIG_DICT_AES = DEVICE_CONFIG_AES.to_dict(exclude_credentials=True) DEVICE_CONFIG_DICT_AES = DEVICE_CONFIG_AES.to_dict(exclude_credentials=True)
CREATE_ENTRY_DATA_LEGACY = { CREATE_ENTRY_DATA_LEGACY = {
CONF_HOST: IP_ADDRESS, CONF_HOST: IP_ADDRESS,
CONF_ALIAS: ALIAS, CONF_ALIAS: ALIAS,
CONF_MODEL: MODEL, CONF_MODEL: MODEL,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, CONF_CONNECTION_PARAMETERS: CONN_PARAMS_LEGACY.to_dict(),
CONF_USES_HTTP: False,
} }
CREATE_ENTRY_DATA_KLAP = { CREATE_ENTRY_DATA_KLAP = {
@ -90,23 +99,18 @@ CREATE_ENTRY_DATA_KLAP = {
CONF_ALIAS: ALIAS, CONF_ALIAS: ALIAS,
CONF_MODEL: MODEL, CONF_MODEL: MODEL,
CONF_CREDENTIALS_HASH: CREDENTIALS_HASH_KLAP, CONF_CREDENTIALS_HASH: CREDENTIALS_HASH_KLAP,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, CONF_CONNECTION_PARAMETERS: CONN_PARAMS_KLAP.to_dict(),
CONF_USES_HTTP: True,
} }
CREATE_ENTRY_DATA_AES = { CREATE_ENTRY_DATA_AES = {
CONF_HOST: IP_ADDRESS2, CONF_HOST: IP_ADDRESS2,
CONF_ALIAS: ALIAS, CONF_ALIAS: ALIAS,
CONF_MODEL: MODEL, CONF_MODEL: MODEL,
CONF_CREDENTIALS_HASH: CREDENTIALS_HASH_AES, CONF_CREDENTIALS_HASH: CREDENTIALS_HASH_AES,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_AES, CONF_CONNECTION_PARAMETERS: CONN_PARAMS_AES.to_dict(),
CONF_USES_HTTP: True,
CONF_AES_KEYS: AES_KEYS,
} }
CONNECTION_TYPE_KLAP = DeviceConnectionParameters(
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Klap
)
CONNECTION_TYPE_KLAP_DICT = CONNECTION_TYPE_KLAP.to_dict()
CONNECTION_TYPE_AES = DeviceConnectionParameters(
DeviceFamily.SmartTapoPlug, DeviceEncryptionType.Aes
)
CONNECTION_TYPE_AES_DICT = CONNECTION_TYPE_AES.to_dict()
def _load_feature_fixtures(): def _load_feature_fixtures():
@ -452,11 +456,11 @@ MODULE_TO_MOCK_GEN = {
} }
def _patch_discovery(device=None, no_device=False): def _patch_discovery(device=None, no_device=False, ip_address=IP_ADDRESS):
async def _discovery(*args, **kwargs): async def _discovery(*args, **kwargs):
if no_device: if no_device:
return {} return {}
return {IP_ADDRESS: _mocked_device()} return {ip_address: device if device else _mocked_device()}
return patch("homeassistant.components.tplink.Discover.discover", new=_discovery) return patch("homeassistant.components.tplink.Discover.discover", new=_discovery)

View File

@ -1,9 +1,9 @@
"""tplink conftest.""" """tplink conftest."""
from collections.abc import Generator from collections.abc import Generator
import copy
from unittest.mock import DEFAULT, AsyncMock, patch from unittest.mock import DEFAULT, AsyncMock, patch
from kasa import DeviceConfig
import pytest import pytest
from homeassistant.components.tplink import DOMAIN from homeassistant.components.tplink import DOMAIN
@ -34,13 +34,13 @@ def mock_discovery():
discover_single=DEFAULT, discover_single=DEFAULT,
) as mock_discovery: ) as mock_discovery:
device = _mocked_device( device = _mocked_device(
device_config=copy.deepcopy(DEVICE_CONFIG_KLAP), device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
credentials_hash=CREDENTIALS_HASH_KLAP, credentials_hash=CREDENTIALS_HASH_KLAP,
alias=None, alias=None,
) )
devices = { devices = {
"127.0.0.1": _mocked_device( "127.0.0.1": _mocked_device(
device_config=copy.deepcopy(DEVICE_CONFIG_KLAP), device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
credentials_hash=CREDENTIALS_HASH_KLAP, credentials_hash=CREDENTIALS_HASH_KLAP,
alias=None, alias=None,
) )
@ -57,12 +57,12 @@ def mock_connect():
with patch("homeassistant.components.tplink.Device.connect") as mock_connect: with patch("homeassistant.components.tplink.Device.connect") as mock_connect:
devices = { devices = {
IP_ADDRESS: _mocked_device( IP_ADDRESS: _mocked_device(
device_config=DEVICE_CONFIG_KLAP, device_config=DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict()),
credentials_hash=CREDENTIALS_HASH_KLAP, credentials_hash=CREDENTIALS_HASH_KLAP,
ip_address=IP_ADDRESS, ip_address=IP_ADDRESS,
), ),
IP_ADDRESS2: _mocked_device( IP_ADDRESS2: _mocked_device(
device_config=DEVICE_CONFIG_AES, device_config=DeviceConfig.from_dict(DEVICE_CONFIG_AES.to_dict()),
credentials_hash=CREDENTIALS_HASH_AES, credentials_hash=CREDENTIALS_HASH_AES,
mac=MAC_ADDRESS2, mac=MAC_ADDRESS2,
ip_address=IP_ADDRESS2, ip_address=IP_ADDRESS2,

View File

@ -1,5 +1,6 @@
"""Test the tplink config flow.""" """Test the tplink config flow."""
from contextlib import contextmanager
import logging import logging
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@ -17,7 +18,7 @@ from homeassistant.components.tplink import (
KasaException, KasaException,
) )
from homeassistant.components.tplink.const import ( from homeassistant.components.tplink.const import (
CONF_CONNECTION_TYPE, CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH, CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG, CONF_DEVICE_CONFIG,
) )
@ -34,17 +35,21 @@ from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from . import ( from . import (
AES_KEYS,
ALIAS, ALIAS,
CONNECTION_TYPE_KLAP_DICT, CONN_PARAMS_AES,
CONN_PARAMS_KLAP,
CONN_PARAMS_LEGACY,
CREATE_ENTRY_DATA_AES, CREATE_ENTRY_DATA_AES,
CREATE_ENTRY_DATA_KLAP, CREATE_ENTRY_DATA_KLAP,
CREATE_ENTRY_DATA_LEGACY, CREATE_ENTRY_DATA_LEGACY,
CREDENTIALS_HASH_AES, CREDENTIALS_HASH_AES,
CREDENTIALS_HASH_KLAP, CREDENTIALS_HASH_KLAP,
DEFAULT_ENTRY_TITLE, DEFAULT_ENTRY_TITLE,
DEVICE_CONFIG_DICT_AES, DEVICE_CONFIG_AES,
DEVICE_CONFIG_DICT_KLAP, DEVICE_CONFIG_DICT_KLAP,
DEVICE_CONFIG_DICT_LEGACY, DEVICE_CONFIG_KLAP,
DEVICE_CONFIG_LEGACY,
DHCP_FORMATTED_MAC_ADDRESS, DHCP_FORMATTED_MAC_ADDRESS,
IP_ADDRESS, IP_ADDRESS,
MAC_ADDRESS, MAC_ADDRESS,
@ -59,9 +64,44 @@ from . import (
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_discovery(hass: HomeAssistant) -> None: @contextmanager
def override_side_effect(mock: AsyncMock, effect):
"""Temporarily override a mock side effect and replace afterwards."""
try:
default_side_effect = mock.side_effect
mock.side_effect = effect
yield mock
finally:
mock.side_effect = default_side_effect
@pytest.mark.parametrize(
("device_config", "expected_entry_data", "credentials_hash"),
[
pytest.param(
DEVICE_CONFIG_KLAP, CREATE_ENTRY_DATA_KLAP, CREDENTIALS_HASH_KLAP, id="KLAP"
),
pytest.param(
DEVICE_CONFIG_AES, CREATE_ENTRY_DATA_AES, CREDENTIALS_HASH_AES, id="AES"
),
pytest.param(DEVICE_CONFIG_LEGACY, CREATE_ENTRY_DATA_LEGACY, None, id="Legacy"),
],
)
async def test_discovery(
hass: HomeAssistant, device_config, expected_entry_data, credentials_hash
) -> None:
"""Test setting up discovery.""" """Test setting up discovery."""
with _patch_discovery(), _patch_single_discovery(), _patch_connect(): ip_address = device_config.host
device = _mocked_device(
device_config=device_config,
credentials_hash=credentials_hash,
ip_address=ip_address,
)
with (
_patch_discovery(device, ip_address=ip_address),
_patch_single_discovery(device),
_patch_connect(device),
):
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
@ -91,9 +131,9 @@ async def test_discovery(hass: HomeAssistant) -> None:
assert not result2["errors"] assert not result2["errors"]
with ( with (
_patch_discovery(), _patch_discovery(device, ip_address=ip_address),
_patch_single_discovery(), _patch_single_discovery(device),
_patch_connect(), _patch_connect(device),
patch(f"{MODULE}.async_setup", return_value=True) as mock_setup, patch(f"{MODULE}.async_setup", return_value=True) as mock_setup,
patch(f"{MODULE}.async_setup_entry", return_value=True) as mock_setup_entry, patch(f"{MODULE}.async_setup_entry", return_value=True) as mock_setup_entry,
): ):
@ -105,7 +145,7 @@ async def test_discovery(hass: HomeAssistant) -> None:
assert result3["type"] is FlowResultType.CREATE_ENTRY assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["title"] == DEFAULT_ENTRY_TITLE assert result3["title"] == DEFAULT_ENTRY_TITLE
assert result3["data"] == CREATE_ENTRY_DATA_LEGACY assert result3["data"] == expected_entry_data
mock_setup.assert_called_once() mock_setup.assert_called_once()
mock_setup_entry.assert_called_once() mock_setup_entry.assert_called_once()
@ -130,24 +170,25 @@ async def test_discovery_auth(
) -> None: ) -> None:
"""Test authenticated discovery.""" """Test authenticated discovery."""
mock_discovery["mock_device"].update.side_effect = AuthenticationError mock_device = mock_connect["mock_devices"][IP_ADDRESS]
assert mock_device.config == DEVICE_CONFIG_KLAP
result = await hass.config_entries.flow.async_init( with override_side_effect(mock_connect["connect"], AuthenticationError):
DOMAIN, result = await hass.config_entries.flow.async_init(
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, DOMAIN,
data={ context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
CONF_HOST: IP_ADDRESS, data={
CONF_MAC: MAC_ADDRESS, CONF_HOST: IP_ADDRESS,
CONF_ALIAS: ALIAS, CONF_MAC: MAC_ADDRESS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, CONF_ALIAS: ALIAS,
}, CONF_DEVICE: mock_device,
) },
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "discovery_auth_confirm" assert result["step_id"] == "discovery_auth_confirm"
assert not result["errors"] assert not result["errors"]
mock_discovery["mock_device"].update.reset_mock(side_effect=True)
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
user_input={ user_input={
@ -172,40 +213,43 @@ async def test_discovery_auth(
) )
async def test_discovery_auth_errors( async def test_discovery_auth_errors(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock, mock_connect: AsyncMock,
mock_init, mock_init,
error_type, error_type,
errors_msg, errors_msg,
error_placement, error_placement,
) -> None: ) -> None:
"""Test handling of discovery authentication errors.""" """Test handling of discovery authentication errors.
mock_discovery["mock_device"].update.side_effect = AuthenticationError
default_connect_side_effect = mock_connect["connect"].side_effect
mock_connect["connect"].side_effect = error_type
result = await hass.config_entries.flow.async_init( Tests for errors received during credential
DOMAIN, entry during discovery_auth_confirm.
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, """
data={ mock_device = mock_connect["mock_devices"][IP_ADDRESS]
CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS, with override_side_effect(mock_connect["connect"], AuthenticationError):
CONF_ALIAS: ALIAS, result = await hass.config_entries.flow.async_init(
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, DOMAIN,
}, context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
) data={
await hass.async_block_till_done() CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE: mock_device,
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "discovery_auth_confirm" assert result["step_id"] == "discovery_auth_confirm"
assert not result["errors"] assert not result["errors"]
result2 = await hass.config_entries.flow.async_configure( with override_side_effect(mock_connect["connect"], error_type):
result["flow_id"], result2 = await hass.config_entries.flow.async_configure(
user_input={ result["flow_id"],
CONF_USERNAME: "fake_username", user_input={
CONF_PASSWORD: "fake_password", CONF_USERNAME: "fake_username",
}, CONF_PASSWORD: "fake_password",
) },
)
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {error_placement: errors_msg} assert result2["errors"] == {error_placement: errors_msg}
@ -213,7 +257,6 @@ async def test_discovery_auth_errors(
await hass.async_block_till_done() await hass.async_block_till_done()
mock_connect["connect"].side_effect = default_connect_side_effect
result3 = await hass.config_entries.flow.async_configure( result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], result2["flow_id"],
{ {
@ -228,29 +271,29 @@ async def test_discovery_auth_errors(
async def test_discovery_new_credentials( async def test_discovery_new_credentials(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock, mock_connect: AsyncMock,
mock_init, mock_init,
) -> None: ) -> None:
"""Test setting up discovery with new credentials.""" """Test setting up discovery with new credentials."""
mock_discovery["mock_device"].update.side_effect = AuthenticationError mock_device = mock_connect["mock_devices"][IP_ADDRESS]
result = await hass.config_entries.flow.async_init( with override_side_effect(mock_connect["connect"], AuthenticationError):
DOMAIN, result = await hass.config_entries.flow.async_init(
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, DOMAIN,
data={ context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
CONF_HOST: IP_ADDRESS, data={
CONF_MAC: MAC_ADDRESS, CONF_HOST: IP_ADDRESS,
CONF_ALIAS: ALIAS, CONF_MAC: MAC_ADDRESS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, CONF_ALIAS: ALIAS,
}, CONF_DEVICE: mock_device,
) },
await hass.async_block_till_done() )
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "discovery_auth_confirm" assert result["step_id"] == "discovery_auth_confirm"
assert not result["errors"] assert not result["errors"]
assert mock_connect["connect"].call_count == 0 assert mock_connect["connect"].call_count == 1
with patch( with patch(
"homeassistant.components.tplink.config_flow.get_credentials", "homeassistant.components.tplink.config_flow.get_credentials",
@ -260,7 +303,7 @@ async def test_discovery_new_credentials(
result["flow_id"], result["flow_id"],
) )
assert mock_connect["connect"].call_count == 1 assert mock_connect["connect"].call_count == 2
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["step_id"] == "discovery_confirm" assert result2["step_id"] == "discovery_confirm"
@ -277,48 +320,54 @@ async def test_discovery_new_credentials(
async def test_discovery_new_credentials_invalid( async def test_discovery_new_credentials_invalid(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock,
mock_connect: AsyncMock, mock_connect: AsyncMock,
mock_init, mock_init,
) -> None: ) -> None:
"""Test setting up discovery with new invalid credentials.""" """Test setting up discovery with new invalid credentials."""
mock_discovery["mock_device"].update.side_effect = AuthenticationError mock_device = mock_connect["mock_devices"][IP_ADDRESS]
default_connect_side_effect = mock_connect["connect"].side_effect
mock_connect["connect"].side_effect = AuthenticationError with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
result = await hass.config_entries.flow.async_init( patch(
DOMAIN, "homeassistant.components.tplink.config_flow.get_credentials",
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, return_value=None,
data={ ),
CONF_HOST: IP_ADDRESS, override_side_effect(mock_connect["connect"], AuthenticationError),
CONF_MAC: MAC_ADDRESS, ):
CONF_ALIAS: ALIAS, result = await hass.config_entries.flow.async_init(
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, DOMAIN,
}, context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
) data={
await hass.async_block_till_done() CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE: mock_device,
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "discovery_auth_confirm" assert result["step_id"] == "discovery_auth_confirm"
assert not result["errors"] assert not result["errors"]
assert mock_connect["connect"].call_count == 0 assert mock_connect["connect"].call_count == 1
with patch( with (
"homeassistant.components.tplink.config_flow.get_credentials", patch(
return_value=Credentials("fake_user", "fake_pass"), "homeassistant.components.tplink.config_flow.get_credentials",
return_value=Credentials("fake_user", "fake_pass"),
),
override_side_effect(mock_connect["connect"], AuthenticationError),
): ):
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
) )
assert mock_connect["connect"].call_count == 1 assert mock_connect["connect"].call_count == 2
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["step_id"] == "discovery_auth_confirm" assert result2["step_id"] == "discovery_auth_confirm"
await hass.async_block_till_done() await hass.async_block_till_done()
mock_connect["connect"].side_effect = default_connect_side_effect
result3 = await hass.config_entries.flow.async_configure( result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], result2["flow_id"],
{ {
@ -577,32 +626,30 @@ async def test_manual_auth_errors(
assert not result["errors"] assert not result["errors"]
mock_discovery["mock_device"].update.side_effect = AuthenticationError mock_discovery["mock_device"].update.side_effect = AuthenticationError
default_connect_side_effect = mock_connect["connect"].side_effect
mock_connect["connect"].side_effect = error_type
result2 = await hass.config_entries.flow.async_configure( with override_side_effect(mock_connect["connect"], error_type):
result["flow_id"], user_input={CONF_HOST: IP_ADDRESS} result2 = await hass.config_entries.flow.async_configure(
) result["flow_id"], user_input={CONF_HOST: IP_ADDRESS}
)
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["step_id"] == "user_auth_confirm" assert result2["step_id"] == "user_auth_confirm"
assert not result2["errors"] assert not result2["errors"]
await hass.async_block_till_done() await hass.async_block_till_done()
with override_side_effect(mock_connect["connect"], error_type):
result3 = await hass.config_entries.flow.async_configure( result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], result2["flow_id"],
user_input={ user_input={
CONF_USERNAME: "fake_username", CONF_USERNAME: "fake_username",
CONF_PASSWORD: "fake_password", CONF_PASSWORD: "fake_password",
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result3["type"] is FlowResultType.FORM assert result3["type"] is FlowResultType.FORM
assert result3["step_id"] == "user_auth_confirm" assert result3["step_id"] == "user_auth_confirm"
assert result3["errors"] == {error_placement: errors_msg} assert result3["errors"] == {error_placement: errors_msg}
assert result3["description_placeholders"]["error"] == str(error_type) assert result3["description_placeholders"]["error"] == str(error_type)
mock_connect["connect"].side_effect = default_connect_side_effect
result4 = await hass.config_entries.flow.async_configure( result4 = await hass.config_entries.flow.async_configure(
result3["flow_id"], result3["flow_id"],
{ {
@ -628,7 +675,7 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
CONF_HOST: IP_ADDRESS, CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS, CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS, CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, CONF_DEVICE: _mocked_device(device_config=DEVICE_CONFIG_LEGACY),
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -691,7 +738,7 @@ async def test_discovered_by_discovery_and_dhcp(hass: HomeAssistant) -> None:
CONF_HOST: IP_ADDRESS, CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS, CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS, CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, CONF_DEVICE: _mocked_device(device_config=DEVICE_CONFIG_LEGACY),
}, },
), ),
], ],
@ -745,7 +792,7 @@ async def test_discovered_by_dhcp_or_discovery(
CONF_HOST: IP_ADDRESS, CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS, CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS, CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_LEGACY, CONF_DEVICE: _mocked_device(device_config=DEVICE_CONFIG_LEGACY),
}, },
), ),
], ],
@ -775,9 +822,11 @@ async def test_integration_discovery_with_ip_change(
mock_connect: AsyncMock, mock_connect: AsyncMock,
) -> None: ) -> None:
"""Test reauth flow.""" """Test reauth flow."""
mock_connect["connect"].side_effect = KasaException()
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], KasaException()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -785,39 +834,57 @@ async def test_integration_discovery_with_ip_change(
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress()
assert len(flows) == 0 assert len(flows) == 0
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY assert (
assert mock_config_entry.data[CONF_DEVICE_CONFIG].get(CONF_HOST) == "127.0.0.1" mock_config_entry.data[CONF_CONNECTION_PARAMETERS]
== CONN_PARAMS_LEGACY.to_dict()
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: "127.0.0.2",
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP,
},
) )
assert mock_config_entry.data[CONF_HOST] == "127.0.0.1"
mocked_device = _mocked_device(device_config=DEVICE_CONFIG_KLAP)
with override_side_effect(mock_connect["connect"], lambda *_, **__: mocked_device):
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: "127.0.0.2",
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE: mocked_device,
},
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "already_configured" assert discovery_result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
)
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
config = DeviceConfig.from_dict(DEVICE_CONFIG_DICT_KLAP) config = DeviceConfig.from_dict(DEVICE_CONFIG_DICT_KLAP)
# Do a reload here and check that the
# new config is picked up in setup_entry
mock_connect["connect"].reset_mock(side_effect=True) mock_connect["connect"].reset_mock(side_effect=True)
bulb = _mocked_device( bulb = _mocked_device(
device_config=config, device_config=config,
mac=mock_config_entry.unique_id, mac=mock_config_entry.unique_id,
) )
mock_connect["connect"].return_value = bulb
await hass.config_entries.async_reload(mock_config_entry.entry_id) with (
await hass.async_block_till_done() patch(
"homeassistant.components.tplink.async_create_clientsession",
return_value="Foo",
),
override_side_effect(mock_connect["connect"], lambda *_, **__: bulb),
):
await hass.config_entries.async_reload(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.LOADED assert mock_config_entry.state is ConfigEntryState.LOADED
# Check that init set the new host correctly before calling connect # Check that init set the new host correctly before calling connect
assert config.host == "127.0.0.1" assert config.host == "127.0.0.1"
config.host = "127.0.0.2" config.host = "127.0.0.2"
config.uses_http = False # Not passed in to new config class
config.http_client = "Foo"
mock_connect["connect"].assert_awaited_once_with(config=config) mock_connect["connect"].assert_awaited_once_with(config=config)
@ -831,8 +898,6 @@ async def test_integration_discovery_with_connection_change(
And that connection_hash is removed as it will be invalid. And that connection_hash is removed as it will be invalid.
""" """
mock_connect["connect"].side_effect = KasaException()
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
title="TPLink", title="TPLink",
domain=DOMAIN, domain=DOMAIN,
@ -840,7 +905,10 @@ async def test_integration_discovery_with_connection_change(
unique_id=MAC_ADDRESS2, unique_id=MAC_ADDRESS2,
) )
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], KasaException()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
@ -854,43 +922,57 @@ async def test_integration_discovery_with_connection_change(
== 0 == 0
) )
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_AES assert (
assert mock_config_entry.data[CONF_DEVICE_CONFIG].get(CONF_HOST) == "127.0.0.2" mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_AES.to_dict()
)
assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_AES assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_AES
mock_connect["connect"].reset_mock()
NEW_DEVICE_CONFIG = { NEW_DEVICE_CONFIG = {
**DEVICE_CONFIG_DICT_KLAP, **DEVICE_CONFIG_DICT_KLAP,
CONF_CONNECTION_TYPE: CONNECTION_TYPE_KLAP_DICT, "connection_type": CONN_PARAMS_KLAP.to_dict(),
CONF_HOST: "127.0.0.2", CONF_HOST: "127.0.0.2",
} }
config = DeviceConfig.from_dict(NEW_DEVICE_CONFIG) config = DeviceConfig.from_dict(NEW_DEVICE_CONFIG)
# Reset the connect mock so when the config flow reloads the entry it succeeds # Reset the connect mock so when the config flow reloads the entry it succeeds
mock_connect["connect"].reset_mock(side_effect=True)
bulb = _mocked_device( bulb = _mocked_device(
device_config=config, device_config=config,
mac=mock_config_entry.unique_id, mac=mock_config_entry.unique_id,
) )
mock_connect["connect"].return_value = bulb
discovery_result = await hass.config_entries.flow.async_init( with (
DOMAIN, patch(
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, "homeassistant.components.tplink.async_create_clientsession",
data={ return_value="Foo",
CONF_HOST: "127.0.0.2", ),
CONF_MAC: MAC_ADDRESS2, override_side_effect(mock_connect["connect"], lambda *_, **__: bulb),
CONF_ALIAS: ALIAS, ):
CONF_DEVICE_CONFIG: NEW_DEVICE_CONFIG, discovery_result = await hass.config_entries.flow.async_init(
}, DOMAIN,
) context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: "127.0.0.2",
CONF_MAC: MAC_ADDRESS2,
CONF_ALIAS: ALIAS,
CONF_DEVICE: bulb,
},
)
await hass.async_block_till_done(wait_background_tasks=True) await hass.async_block_till_done(wait_background_tasks=True)
assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "already_configured" assert discovery_result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == NEW_DEVICE_CONFIG assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
)
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
assert CREDENTIALS_HASH_AES not in mock_config_entry.data assert CREDENTIALS_HASH_AES not in mock_config_entry.data
assert mock_config_entry.state is ConfigEntryState.LOADED assert mock_config_entry.state is ConfigEntryState.LOADED
config.host = "127.0.0.2"
config.uses_http = False # Not passed in to new config class
config.http_client = "Foo"
config.aes_keys = AES_KEYS
mock_connect["connect"].assert_awaited_once_with(config=config) mock_connect["connect"].assert_awaited_once_with(config=config)
@ -901,17 +983,18 @@ async def test_dhcp_discovery_with_ip_change(
mock_connect: AsyncMock, mock_connect: AsyncMock,
) -> None: ) -> None:
"""Test dhcp discovery with an IP change.""" """Test dhcp discovery with an IP change."""
mock_connect["connect"].side_effect = KasaException()
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], KasaException()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress()
assert len(flows) == 0 assert len(flows) == 0
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY assert mock_config_entry.data[CONF_HOST] == "127.0.0.1"
assert mock_config_entry.data[CONF_DEVICE_CONFIG].get(CONF_HOST) == "127.0.0.1"
discovery_result = await hass.config_entries.flow.async_init( discovery_result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
@ -966,8 +1049,7 @@ async def test_reauth_update_with_encryption_change(
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test reauth flow.""" """Test reauth flow."""
orig_side_effect = mock_connect["connect"].side_effect
mock_connect["connect"].side_effect = AuthenticationError()
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
title="TPLink", title="TPLink",
domain=DOMAIN, domain=DOMAIN,
@ -975,10 +1057,15 @@ async def test_reauth_update_with_encryption_change(
unique_id=MAC_ADDRESS2, unique_id=MAC_ADDRESS2,
) )
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_AES assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_AES.to_dict()
)
assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_AES assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_AES
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], AuthenticationError()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
@ -988,7 +1075,9 @@ async def test_reauth_update_with_encryption_change(
assert len(flows) == 1 assert len(flows) == 1
[result] = flows [result] = flows
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_AES assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_AES.to_dict()
)
assert CONF_CREDENTIALS_HASH not in mock_config_entry.data assert CONF_CREDENTIALS_HASH not in mock_config_entry.data
new_config = DeviceConfig( new_config = DeviceConfig(
@ -1005,7 +1094,6 @@ async def test_reauth_update_with_encryption_change(
mock_connect["mock_devices"]["127.0.0.2"].config = new_config mock_connect["mock_devices"]["127.0.0.2"].config = new_config
mock_connect["mock_devices"]["127.0.0.2"].credentials_hash = CREDENTIALS_HASH_KLAP mock_connect["mock_devices"]["127.0.0.2"].credentials_hash = CREDENTIALS_HASH_KLAP
mock_connect["connect"].side_effect = orig_side_effect
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
user_input={ user_input={
@ -1023,10 +1111,10 @@ async def test_reauth_update_with_encryption_change(
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT
assert result2["reason"] == "reauth_successful" assert result2["reason"] == "reauth_successful"
assert mock_config_entry.state is ConfigEntryState.LOADED assert mock_config_entry.state is ConfigEntryState.LOADED
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == { assert (
**DEVICE_CONFIG_DICT_KLAP, mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
CONF_HOST: "127.0.0.2", )
} assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_KLAP assert mock_config_entry.data[CONF_CREDENTIALS_HASH] == CREDENTIALS_HASH_KLAP
@ -1037,9 +1125,11 @@ async def test_reauth_update_from_discovery(
mock_connect: AsyncMock, mock_connect: AsyncMock,
) -> None: ) -> None:
"""Test reauth flow.""" """Test reauth flow."""
mock_connect["connect"].side_effect = AuthenticationError
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], AuthenticationError()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -1049,22 +1139,32 @@ async def test_reauth_update_from_discovery(
assert len(flows) == 1 assert len(flows) == 1
[result] = flows [result] = flows
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS]
discovery_result = await hass.config_entries.flow.async_init( == CONN_PARAMS_LEGACY.to_dict()
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP,
},
) )
device = _mocked_device(
device_config=DEVICE_CONFIG_KLAP,
mac=mock_config_entry.unique_id,
)
with override_side_effect(mock_connect["connect"], lambda *_, **__: device):
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE: device,
},
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "already_configured" assert discovery_result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
)
async def test_reauth_update_from_discovery_with_ip_change( async def test_reauth_update_from_discovery_with_ip_change(
@ -1074,9 +1174,11 @@ async def test_reauth_update_from_discovery_with_ip_change(
mock_connect: AsyncMock, mock_connect: AsyncMock,
) -> None: ) -> None:
"""Test reauth flow.""" """Test reauth flow."""
mock_connect["connect"].side_effect = AuthenticationError()
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], AuthenticationError()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
@ -1085,22 +1187,32 @@ async def test_reauth_update_from_discovery_with_ip_change(
assert len(flows) == 1 assert len(flows) == 1
[result] = flows [result] = flows
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_LEGACY assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS]
discovery_result = await hass.config_entries.flow.async_init( == CONN_PARAMS_LEGACY.to_dict()
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: "127.0.0.2",
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP,
},
) )
device = _mocked_device(
device_config=DEVICE_CONFIG_KLAP,
mac=mock_config_entry.unique_id,
)
with override_side_effect(mock_connect["connect"], lambda *_, **__: device):
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: "127.0.0.2",
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE: device,
},
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "already_configured" assert discovery_result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
)
assert mock_config_entry.data[CONF_HOST] == "127.0.0.2" assert mock_config_entry.data[CONF_HOST] == "127.0.0.2"
@ -1111,8 +1223,8 @@ async def test_reauth_no_update_if_config_and_ip_the_same(
mock_connect: AsyncMock, mock_connect: AsyncMock,
) -> None: ) -> None:
"""Test reauth discovery does not update when the host and config are the same.""" """Test reauth discovery does not update when the host and config are the same."""
mock_connect["connect"].side_effect = AuthenticationError()
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
mock_config_entry, mock_config_entry,
data={ data={
@ -1120,30 +1232,40 @@ async def test_reauth_no_update_if_config_and_ip_the_same(
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP, CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP,
}, },
) )
await hass.config_entries.async_setup(mock_config_entry.entry_id) with override_side_effect(mock_connect["connect"], AuthenticationError()):
await hass.async_block_till_done() await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1 assert len(flows) == 1
[result] = flows [result] = flows
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE_CONFIG: DEVICE_CONFIG_DICT_KLAP,
},
) )
device = _mocked_device(
device_config=DEVICE_CONFIG_KLAP,
mac=mock_config_entry.unique_id,
)
with override_side_effect(mock_connect["connect"], lambda *_, **__: device):
discovery_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={
CONF_HOST: IP_ADDRESS,
CONF_MAC: MAC_ADDRESS,
CONF_ALIAS: ALIAS,
CONF_DEVICE: device,
},
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert discovery_result["type"] is FlowResultType.ABORT assert discovery_result["type"] is FlowResultType.ABORT
assert discovery_result["reason"] == "already_configured" assert discovery_result["reason"] == "already_configured"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
)
assert mock_config_entry.data[CONF_HOST] == IP_ADDRESS assert mock_config_entry.data[CONF_HOST] == IP_ADDRESS
@ -1241,17 +1363,15 @@ async def test_pick_device_errors(
assert result2["step_id"] == "pick_device" assert result2["step_id"] == "pick_device"
assert not result2["errors"] assert not result2["errors"]
default_connect_side_effect = mock_connect["connect"].side_effect with override_side_effect(mock_connect["connect"], error_type):
mock_connect["connect"].side_effect = error_type result3 = await hass.config_entries.flow.async_configure(
result3 = await hass.config_entries.flow.async_configure( result2["flow_id"],
result2["flow_id"], {CONF_DEVICE: MAC_ADDRESS},
{CONF_DEVICE: MAC_ADDRESS}, )
) await hass.async_block_till_done()
await hass.async_block_till_done()
assert result3["type"] == expected_flow assert result3["type"] == expected_flow
if expected_flow != FlowResultType.ABORT: if expected_flow != FlowResultType.ABORT:
mock_connect["connect"].side_effect = default_connect_side_effect
result4 = await hass.config_entries.flow.async_configure( result4 = await hass.config_entries.flow.async_configure(
result3["flow_id"], result3["flow_id"],
user_input={ user_input={
@ -1300,17 +1420,17 @@ async def test_discovery_timeout_connect_legacy_error(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
mock_discovery["discover_single"].side_effect = TimeoutError mock_discovery["discover_single"].side_effect = TimeoutError
mock_connect["connect"].side_effect = KasaException
await hass.async_block_till_done() await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
assert not result["errors"] assert not result["errors"]
assert mock_connect["connect"].call_count == 0 assert mock_connect["connect"].call_count == 0
result2 = await hass.config_entries.flow.async_configure( with override_side_effect(mock_connect["connect"], KasaException):
result["flow_id"], {CONF_HOST: IP_ADDRESS} result2 = await hass.config_entries.flow.async_configure(
) result["flow_id"], {CONF_HOST: IP_ADDRESS}
await hass.async_block_till_done() )
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": "cannot_connect"} assert result2["errors"] == {"base": "cannot_connect"}
assert mock_connect["connect"].call_count == 1 assert mock_connect["connect"].call_count == 1
@ -1334,17 +1454,17 @@ async def test_reauth_update_other_flows(
data={**CREATE_ENTRY_DATA_AES}, data={**CREATE_ENTRY_DATA_AES},
unique_id=MAC_ADDRESS2, unique_id=MAC_ADDRESS2,
) )
default_side_effect = mock_connect["connect"].side_effect
mock_connect["connect"].side_effect = AuthenticationError()
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
mock_config_entry2.add_to_hass(hass) mock_config_entry2.add_to_hass(hass)
with patch("homeassistant.components.tplink.Discover.discover", return_value={}): with (
patch("homeassistant.components.tplink.Discover.discover", return_value={}),
override_side_effect(mock_connect["connect"], AuthenticationError()),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry2.state is ConfigEntryState.SETUP_ERROR assert mock_config_entry2.state is ConfigEntryState.SETUP_ERROR
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
mock_connect["connect"].side_effect = default_side_effect
await hass.async_block_till_done() await hass.async_block_till_done()
@ -1353,7 +1473,9 @@ async def test_reauth_update_other_flows(
flows_by_entry_id = {flow["context"]["entry_id"]: flow for flow in flows} flows_by_entry_id = {flow["context"]["entry_id"]: flow for flow in flows}
result = flows_by_entry_id[mock_config_entry.entry_id] result = flows_by_entry_id[mock_config_entry.entry_id]
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert mock_config_entry.data[CONF_DEVICE_CONFIG] == DEVICE_CONFIG_DICT_KLAP assert (
mock_config_entry.data[CONF_CONNECTION_PARAMETERS] == CONN_PARAMS_KLAP.to_dict()
)
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
user_input={ user_input={

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy import copy
from datetime import timedelta from datetime import timedelta
from typing import Any
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@ -13,14 +14,18 @@ import pytest
from homeassistant import setup from homeassistant import setup
from homeassistant.components import tplink from homeassistant.components import tplink
from homeassistant.components.tplink.const import ( from homeassistant.components.tplink.const import (
CONF_AES_KEYS,
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH, CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG, CONF_DEVICE_CONFIG,
DOMAIN, DOMAIN,
) )
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
from homeassistant.const import ( from homeassistant.const import (
CONF_ALIAS,
CONF_AUTHENTICATION, CONF_AUTHENTICATION,
CONF_HOST, CONF_HOST,
CONF_MODEL,
CONF_PASSWORD, CONF_PASSWORD,
CONF_USERNAME, CONF_USERNAME,
STATE_ON, STATE_ON,
@ -33,13 +38,20 @@ from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import ( from . import (
ALIAS,
CREATE_ENTRY_DATA_AES,
CREATE_ENTRY_DATA_KLAP, CREATE_ENTRY_DATA_KLAP,
CREATE_ENTRY_DATA_LEGACY, CREATE_ENTRY_DATA_LEGACY,
CREDENTIALS_HASH_AES,
CREDENTIALS_HASH_KLAP,
DEVICE_CONFIG_AES,
DEVICE_CONFIG_KLAP, DEVICE_CONFIG_KLAP,
DEVICE_CONFIG_LEGACY,
DEVICE_ID, DEVICE_ID,
DEVICE_ID_MAC, DEVICE_ID_MAC,
IP_ADDRESS, IP_ADDRESS,
MAC_ADDRESS, MAC_ADDRESS,
MODEL,
_mocked_device, _mocked_device,
_patch_connect, _patch_connect,
_patch_discovery, _patch_discovery,
@ -207,16 +219,21 @@ async def test_config_entry_with_stored_credentials(
hass.data.setdefault(DOMAIN, {})[CONF_AUTHENTICATION] = auth hass.data.setdefault(DOMAIN, {})[CONF_AUTHENTICATION] = auth
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id) with patch(
"homeassistant.components.tplink.async_create_clientsession", return_value="Foo"
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.LOADED assert mock_config_entry.state is ConfigEntryState.LOADED
config = DEVICE_CONFIG_KLAP config = DeviceConfig.from_dict(DEVICE_CONFIG_KLAP.to_dict())
config.uses_http = False
config.http_client = "Foo"
assert config.credentials != stored_credentials assert config.credentials != stored_credentials
config.credentials = stored_credentials config.credentials = stored_credentials
mock_connect["connect"].assert_called_once_with(config=config) mock_connect["connect"].assert_called_once_with(config=config)
async def test_config_entry_device_config_invalid( async def test_config_entry_conn_params_invalid(
hass: HomeAssistant, hass: HomeAssistant,
mock_discovery: AsyncMock, mock_discovery: AsyncMock,
mock_connect: AsyncMock, mock_connect: AsyncMock,
@ -224,7 +241,7 @@ async def test_config_entry_device_config_invalid(
) -> None: ) -> None:
"""Test that an invalid device config logs an error and loads the config entry.""" """Test that an invalid device config logs an error and loads the config entry."""
entry_data = copy.deepcopy(CREATE_ENTRY_DATA_KLAP) entry_data = copy.deepcopy(CREATE_ENTRY_DATA_KLAP)
entry_data[CONF_DEVICE_CONFIG] = {"foo": "bar"} entry_data[CONF_CONNECTION_PARAMETERS] = {"foo": "bar"}
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
title="TPLink", title="TPLink",
domain=DOMAIN, domain=DOMAIN,
@ -237,7 +254,7 @@ async def test_config_entry_device_config_invalid(
assert mock_config_entry.state is ConfigEntryState.LOADED assert mock_config_entry.state is ConfigEntryState.LOADED
assert ( assert (
f"Invalid connection type dict for {IP_ADDRESS}: {entry_data.get(CONF_DEVICE_CONFIG)}" f"Invalid connection parameters dict for {IP_ADDRESS}: {entry_data.get(CONF_CONNECTION_PARAMETERS)}"
in caplog.text in caplog.text
) )
@ -495,8 +512,9 @@ async def test_unlink_devices(
} }
assert device_entries[0].identifiers == set(test_identifiers) assert device_entries[0].identifiers == set(test_identifiers)
await hass.config_entries.async_setup(entry.entry_id) with patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 3):
await hass.async_block_till_done() await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id) device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id)
@ -504,7 +522,7 @@ async def test_unlink_devices(
assert device_entries[0].identifiers == set(expected_identifiers) assert device_entries[0].identifiers == set(expected_identifiers)
assert entry.version == 1 assert entry.version == 1
assert entry.minor_version == 4 assert entry.minor_version == 3
assert update_msg in caplog.text assert update_msg in caplog.text
assert "Migration to version 1.3 complete" in caplog.text assert "Migration to version 1.3 complete" in caplog.text
@ -545,6 +563,7 @@ async def test_move_credentials_hash(
with ( with (
patch("homeassistant.components.tplink.Device.connect", new=_connect), patch("homeassistant.components.tplink.Device.connect", new=_connect),
patch("homeassistant.components.tplink.PLATFORMS", []), patch("homeassistant.components.tplink.PLATFORMS", []),
patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 4),
): ):
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -589,6 +608,7 @@ async def test_move_credentials_hash_auth_error(
side_effect=AuthenticationError, side_effect=AuthenticationError,
), ),
patch("homeassistant.components.tplink.PLATFORMS", []), patch("homeassistant.components.tplink.PLATFORMS", []),
patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 4),
): ):
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
@ -631,6 +651,7 @@ async def test_move_credentials_hash_other_error(
"homeassistant.components.tplink.Device.connect", side_effect=KasaException "homeassistant.components.tplink.Device.connect", side_effect=KasaException
), ),
patch("homeassistant.components.tplink.PLATFORMS", []), patch("homeassistant.components.tplink.PLATFORMS", []),
patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 4),
): ):
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
@ -647,10 +668,8 @@ async def test_credentials_hash(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test credentials_hash used to call connect.""" """Test credentials_hash used to call connect."""
device_config = {**DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True)}
entry_data = { entry_data = {
**CREATE_ENTRY_DATA_KLAP, **CREATE_ENTRY_DATA_KLAP,
CONF_DEVICE_CONFIG: device_config,
CONF_CREDENTIALS_HASH: "theHash", CONF_CREDENTIALS_HASH: "theHash",
} }
@ -674,9 +693,7 @@ async def test_credentials_hash(
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
assert CONF_CREDENTIALS_HASH not in entry.data[CONF_DEVICE_CONFIG]
assert CONF_CREDENTIALS_HASH in entry.data assert CONF_CREDENTIALS_HASH in entry.data
assert entry.data[CONF_DEVICE_CONFIG] == device_config
assert entry.data[CONF_CREDENTIALS_HASH] == "theHash" assert entry.data[CONF_CREDENTIALS_HASH] == "theHash"
@ -684,10 +701,8 @@ async def test_credentials_hash_auth_error(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test credentials_hash is deleted after an auth failure.""" """Test credentials_hash is deleted after an auth failure."""
device_config = {**DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True)}
entry_data = { entry_data = {
**CREATE_ENTRY_DATA_KLAP, **CREATE_ENTRY_DATA_KLAP,
CONF_DEVICE_CONFIG: device_config,
CONF_CREDENTIALS_HASH: "theHash", CONF_CREDENTIALS_HASH: "theHash",
} }
@ -700,6 +715,10 @@ async def test_credentials_hash_auth_error(
with ( with (
patch("homeassistant.components.tplink.PLATFORMS", []), patch("homeassistant.components.tplink.PLATFORMS", []),
patch(
"homeassistant.components.tplink.async_create_clientsession",
return_value="Foo",
),
patch( patch(
"homeassistant.components.tplink.Device.connect", "homeassistant.components.tplink.Device.connect",
side_effect=AuthenticationError, side_effect=AuthenticationError,
@ -712,6 +731,76 @@ async def test_credentials_hash_auth_error(
expected_config = DeviceConfig.from_dict( expected_config = DeviceConfig.from_dict(
DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True, credentials_hash="theHash") DEVICE_CONFIG_KLAP.to_dict(exclude_credentials=True, credentials_hash="theHash")
) )
expected_config.uses_http = False
expected_config.http_client = "Foo"
connect_mock.assert_called_with(config=expected_config) connect_mock.assert_called_with(config=expected_config)
assert entry.state is ConfigEntryState.SETUP_ERROR assert entry.state is ConfigEntryState.SETUP_ERROR
assert CONF_CREDENTIALS_HASH not in entry.data assert CONF_CREDENTIALS_HASH not in entry.data
@pytest.mark.parametrize(
("device_config", "expected_entry_data", "credentials_hash"),
[
pytest.param(
DEVICE_CONFIG_KLAP, CREATE_ENTRY_DATA_KLAP, CREDENTIALS_HASH_KLAP, id="KLAP"
),
pytest.param(
DEVICE_CONFIG_AES, CREATE_ENTRY_DATA_AES, CREDENTIALS_HASH_AES, id="AES"
),
pytest.param(DEVICE_CONFIG_LEGACY, CREATE_ENTRY_DATA_LEGACY, None, id="Legacy"),
],
)
async def test_migrate_remove_device_config(
hass: HomeAssistant,
mock_connect: AsyncMock,
caplog: pytest.LogCaptureFixture,
device_config: DeviceConfig,
expected_entry_data: dict[str, Any],
credentials_hash: str,
) -> None:
"""Test credentials hash moved to parent.
As async_setup_entry will succeed the hash on the parent is updated
from the device.
"""
OLD_CREATE_ENTRY_DATA = {
CONF_HOST: expected_entry_data[CONF_HOST],
CONF_ALIAS: ALIAS,
CONF_MODEL: MODEL,
CONF_DEVICE_CONFIG: device_config.to_dict(exclude_credentials=True),
}
entry = MockConfigEntry(
title="TPLink",
domain=DOMAIN,
data=OLD_CREATE_ENTRY_DATA,
entry_id="123456",
unique_id=MAC_ADDRESS,
version=1,
minor_version=4,
)
entry.add_to_hass(hass)
async def _connect(config):
config.credentials_hash = credentials_hash
config.aes_keys = expected_entry_data.get(CONF_AES_KEYS)
return _mocked_device(device_config=config, credentials_hash=credentials_hash)
with (
patch("homeassistant.components.tplink.Device.connect", new=_connect),
patch("homeassistant.components.tplink.PLATFORMS", []),
patch(
"homeassistant.components.tplink.async_create_clientsession",
return_value="Foo",
),
patch("homeassistant.components.tplink.CONF_CONFIG_ENTRY_MINOR_VERSION", 5),
):
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.minor_version == 5
assert entry.state is ConfigEntryState.LOADED
assert CONF_DEVICE_CONFIG not in entry.data
assert entry.data == expected_entry_data
assert "Migration to version 1.5 complete" in caplog.text