diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index bfc96eabfdf..2946c8c3743 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -191,21 +191,6 @@ class MaxLengthExceeded(HomeAssistantError): self.max_length = max_length -class RequiredParameterMissing(HomeAssistantError): - """Raised when a required parameter is missing from a function call.""" - - def __init__(self, parameter_names: list[str]) -> None: - """Initialize error.""" - super().__init__( - self, - ( - "Call must include at least one of the following parameters: " - f"{', '.join(parameter_names)}" - ), - ) - self.parameter_names = parameter_names - - class DependencyError(HomeAssistantError): """Raised when dependencies cannot be setup.""" diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 79b4eac68d5..a59313ed886 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -6,13 +6,14 @@ from collections.abc import Coroutine, ValuesView import logging import time from typing import TYPE_CHECKING, Any, TypeVar, cast +from urllib.parse import urlparse import attr from homeassistant.backports.enum import StrEnum from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.exceptions import HomeAssistantError, RequiredParameterMissing +from homeassistant.exceptions import HomeAssistantError from homeassistant.util.json import format_unserializable_data import homeassistant.util.uuid as uuid_util @@ -26,6 +27,7 @@ if TYPE_CHECKING: from homeassistant.config_entries import ConfigEntry from . import entity_registry + from .entity import DeviceInfo _LOGGER = logging.getLogger(__name__) @@ -60,6 +62,39 @@ DISABLED_CONFIG_ENTRY = DeviceEntryDisabler.CONFIG_ENTRY.value DISABLED_INTEGRATION = DeviceEntryDisabler.INTEGRATION.value DISABLED_USER = DeviceEntryDisabler.USER.value +DEVICE_INFO_TYPES = { + # Device info is categorized by finding the first device info type which has all + # the keys of the device info. The link device info type must be kept first + # to make it preferred over primary. + "link": { + "connections", + "identifiers", + }, + "primary": { + "configuration_url", + "connections", + "entry_type", + "hw_version", + "identifiers", + "manufacturer", + "model", + "name", + "suggested_area", + "sw_version", + "via_device", + }, + "secondary": { + "connections", + "default_manufacturer", + "default_model", + "default_name", + # Used by Fritz + "via_device", + }, +} + +DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values())) + class DeviceEntryType(StrEnum): """Device entry type.""" @@ -67,6 +102,66 @@ class DeviceEntryType(StrEnum): SERVICE = "service" +class DeviceInfoError(HomeAssistantError): + """Raised when device info is invalid.""" + + def __init__(self, domain: str, device_info: DeviceInfo, message: str) -> None: + """Initialize error.""" + super().__init__( + f"Invalid device info {device_info} for '{domain}' config entry: {message}", + ) + self.device_info = device_info + self.domain = domain + + +def _validate_device_info( + config_entry: ConfigEntry | None, + device_info: DeviceInfo, +) -> str: + """Process a device info.""" + keys = set(device_info) + + # If no keys or not enough info to match up, abort + if not device_info.get("connections") and not device_info.get("identifiers"): + raise DeviceInfoError( + config_entry.domain if config_entry else "unknown", + device_info, + "device info must include at least one of identifiers or connections", + ) + + device_info_type: str | None = None + + # Find the first device info type which has all keys in the device info + for possible_type, allowed_keys in DEVICE_INFO_TYPES.items(): + if keys <= allowed_keys: + device_info_type = possible_type + break + + if device_info_type is None: + raise DeviceInfoError( + config_entry.domain if config_entry else "unknown", + device_info, + ( + "device info needs to either describe a device, " + "link to existing device or provide extra information." + ), + ) + + if (config_url := device_info.get("configuration_url")) is not None: + if type(config_url) is not str or urlparse(config_url).scheme not in [ + "http", + "https", + "homeassistant", + ]: + raise DeviceInfoError( + config_entry.domain if config_entry else "unknown", + device_info, + f"invalid configuration_url '{config_url}'", + ) + + return device_info_type + + @attr.s(slots=True, frozen=True) class DeviceEntry: """Device Registry Entry.""" @@ -338,7 +433,7 @@ class DeviceRegistry: *, config_entry_id: str, configuration_url: str | None | UndefinedType = UNDEFINED, - connections: set[tuple[str, str]] | None = None, + connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, default_manufacturer: str | None | UndefinedType = UNDEFINED, default_model: str | None | UndefinedType = UNDEFINED, default_name: str | None | UndefinedType = UNDEFINED, @@ -346,22 +441,47 @@ class DeviceRegistry: disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED, hw_version: str | None | UndefinedType = UNDEFINED, - identifiers: set[tuple[str, str]] | None = None, + identifiers: set[tuple[str, str]] | None | UndefinedType = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED, model: str | None | UndefinedType = UNDEFINED, name: str | None | UndefinedType = UNDEFINED, suggested_area: str | None | UndefinedType = UNDEFINED, sw_version: str | None | UndefinedType = UNDEFINED, - via_device: tuple[str, str] | None = None, + via_device: tuple[str, str] | None | UndefinedType = UNDEFINED, ) -> DeviceEntry: """Get device. Create if it doesn't exist.""" - if not identifiers and not connections: - raise RequiredParameterMissing(["identifiers", "connections"]) - if identifiers is None: + # Reconstruct a DeviceInfo dict from the arguments. + # When we upgrade to Python 3.12, we can change this method to instead + # accept kwargs typed as a DeviceInfo dict (PEP 692) + device_info: DeviceInfo = {} + for key, val in ( + ("configuration_url", configuration_url), + ("connections", connections), + ("default_manufacturer", default_manufacturer), + ("default_model", default_model), + ("default_name", default_name), + ("entry_type", entry_type), + ("hw_version", hw_version), + ("identifiers", identifiers), + ("manufacturer", manufacturer), + ("model", model), + ("name", name), + ("suggested_area", suggested_area), + ("sw_version", sw_version), + ("via_device", via_device), + ): + if val is UNDEFINED: + continue + device_info[key] = val # type: ignore[literal-required] + + config_entry = self.hass.config_entries.async_get_entry(config_entry_id) + device_info_type = _validate_device_info(config_entry, device_info) + + if identifiers is None or identifiers is UNDEFINED: identifiers = set() - if connections is None: + if connections is None or connections is UNDEFINED: connections = set() else: connections = _normalize_connections(connections) @@ -378,6 +498,13 @@ class DeviceRegistry: config_entry_id, connections, identifiers ) self.devices[device.id] = device + # If creating a new device, default to the config entry name + if ( + device_info_type == "primary" + and (not name or name is UNDEFINED) + and config_entry + ): + name = config_entry.title if default_manufacturer is not UNDEFINED and device.manufacturer is None: manufacturer = default_manufacturer @@ -388,7 +515,7 @@ class DeviceRegistry: if default_name is not UNDEFINED and device.name is None: name = default_name - if via_device is not None: + if via_device is not None and via_device is not UNDEFINED: via = self.async_get_device(identifiers={via_device}) via_device_id: str | UndefinedType = via.id if via else UNDEFINED else: diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index f97e509f486..067d6430c9f 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -7,7 +7,6 @@ from contextvars import ContextVar from datetime import datetime, timedelta from logging import Logger, getLogger from typing import TYPE_CHECKING, Any, Protocol -from urllib.parse import urlparse import voluptuous as vol @@ -48,7 +47,7 @@ from .issue_registry import IssueSeverity, async_create_issue from .typing import UNDEFINED, ConfigType, DiscoveryInfoType if TYPE_CHECKING: - from .entity import DeviceInfo, Entity + from .entity import Entity SLOW_SETUP_WARNING = 10 @@ -60,37 +59,6 @@ PLATFORM_NOT_READY_RETRIES = 10 DATA_ENTITY_PLATFORM = "entity_platform" PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds -DEVICE_INFO_TYPES = { - # Device info is categorized by finding the first device info type which has all - # the keys of the device info. The link device info type must be kept first - # to make it preferred over primary. - "link": { - "connections", - "identifiers", - }, - "primary": { - "configuration_url", - "connections", - "entry_type", - "hw_version", - "identifiers", - "manufacturer", - "model", - "name", - "suggested_area", - "sw_version", - "via_device", - }, - "secondary": { - "connections", - "default_manufacturer", - "default_model", - "default_name", - # Used by Fritz - "via_device", - }, -} - _LOGGER = getLogger(__name__) @@ -646,7 +614,14 @@ class EntityPlatform: return if self.config_entry and (device_info := entity.device_info): - device = self._async_process_device_info(device_info) + try: + device = dev_reg.async_get(self.hass).async_get_or_create( + config_entry_id=self.config_entry.entry_id, + **device_info, + ) + except dev_reg.DeviceInfoError as exc: + self.logger.error("Ignoring invalid device info: %s", str(exc)) + device = None else: device = None @@ -773,62 +748,6 @@ class EntityPlatform: await entity.add_to_platform_finish() - @callback - def _async_process_device_info( - self, device_info: DeviceInfo - ) -> dev_reg.DeviceEntry | None: - """Process a device info.""" - keys = set(device_info) - - # If no keys or not enough info to match up, abort - if len(keys & {"connections", "identifiers"}) == 0: - self.logger.error( - "Ignoring device info without identifiers or connections: %s", - device_info, - ) - return None - - device_info_type: str | None = None - - # Find the first device info type which has all keys in the device info - for possible_type, allowed_keys in DEVICE_INFO_TYPES.items(): - if keys <= allowed_keys: - device_info_type = possible_type - break - - if device_info_type is None: - self.logger.error( - "Device info for %s needs to either describe a device, " - "link to existing device or provide extra information.", - device_info, - ) - return None - - if (config_url := device_info.get("configuration_url")) is not None: - if type(config_url) is not str or urlparse(config_url).scheme not in [ - "http", - "https", - "homeassistant", - ]: - self.logger.error( - "Ignoring device info with invalid configuration_url '%s'", - config_url, - ) - return None - - assert self.config_entry is not None - - if device_info_type == "primary" and not device_info.get("name"): - device_info = { - **device_info, # type: ignore[misc] - "name": self.config_entry.title, - } - - return dev_reg.async_get(self.hass).async_get_or_create( - config_entry_id=self.config_entry.entry_id, - **device_info, - ) - async def async_reset(self) -> None: """Remove all entities and reset data. diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 7df5859f502..3e59b08cfa8 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -8,7 +8,7 @@ import pytest from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, HomeAssistant, callback -from homeassistant.exceptions import RequiredParameterMissing +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( area_registry as ar, device_registry as dr, @@ -118,7 +118,7 @@ async def test_requirement_for_identifier_or_connection( assert entry assert entry2 - with pytest.raises(RequiredParameterMissing) as exc_info: + with pytest.raises(HomeAssistantError): device_registry.async_get_or_create( config_entry_id="1234", connections=set(), @@ -127,8 +127,6 @@ async def test_requirement_for_identifier_or_connection( model="model", ) - assert exc_info.value.parameter_names == ["identifiers", "connections"] - async def test_multiple_config_entries(device_registry: dr.DeviceRegistry) -> None: """Make sure we do not get duplicate entries.""" @@ -1462,7 +1460,8 @@ async def test_get_or_create_empty_then_set_default_values( ) -> None: """Test creating an entry, then setting default name, model, manufacturer.""" entry = device_registry.async_get_or_create( - identifiers={("bridgeid", "0123")}, config_entry_id="1234" + config_entry_id="1234", + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) assert entry.name is None assert entry.model is None @@ -1470,7 +1469,7 @@ async def test_get_or_create_empty_then_set_default_values( entry = device_registry.async_get_or_create( config_entry_id="1234", - identifiers={("bridgeid", "0123")}, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, default_name="default name 1", default_model="default model 1", default_manufacturer="default manufacturer 1", @@ -1481,7 +1480,7 @@ async def test_get_or_create_empty_then_set_default_values( entry = device_registry.async_get_or_create( config_entry_id="1234", - identifiers={("bridgeid", "0123")}, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, default_name="default name 2", default_model="default model 2", default_manufacturer="default manufacturer 2", @@ -1496,7 +1495,8 @@ async def test_get_or_create_empty_then_update( ) -> None: """Test creating an entry, then setting name, model, manufacturer.""" entry = device_registry.async_get_or_create( - identifiers={("bridgeid", "0123")}, config_entry_id="1234" + config_entry_id="1234", + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) assert entry.name is None assert entry.model is None @@ -1504,7 +1504,7 @@ async def test_get_or_create_empty_then_update( entry = device_registry.async_get_or_create( config_entry_id="1234", - identifiers={("bridgeid", "0123")}, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, name="name 1", model="model 1", manufacturer="manufacturer 1", @@ -1515,7 +1515,7 @@ async def test_get_or_create_empty_then_update( entry = device_registry.async_get_or_create( config_entry_id="1234", - identifiers={("bridgeid", "0123")}, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, default_name="default name 1", default_model="default model 1", default_manufacturer="default manufacturer 1", @@ -1531,7 +1531,7 @@ async def test_get_or_create_sets_default_values( """Test creating an entry, then setting default name, model, manufacturer.""" entry = device_registry.async_get_or_create( config_entry_id="1234", - identifiers={("bridgeid", "0123")}, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, default_name="default name 1", default_model="default model 1", default_manufacturer="default manufacturer 1", @@ -1542,7 +1542,7 @@ async def test_get_or_create_sets_default_values( entry = device_registry.async_get_or_create( config_entry_id="1234", - identifiers={("bridgeid", "0123")}, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, default_name="default name 2", default_model="default model 2", default_manufacturer="default manufacturer 2", diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 7de6f70e793..0d9ee76ac62 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -972,6 +972,7 @@ async def _test_friendly_name( platform = MockPlatform(async_setup_entry=async_setup_entry) config_entry = MockConfigEntry(entry_id="super-mock-id") + config_entry.add_to_hass(hass) entity_platform = MockEntityPlatform( hass, platform_name=config_entry.domain, platform=platform ) diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index e07c3cb4753..1f7e579ea95 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -1830,6 +1830,7 @@ async def test_device_name_defaulting_config_entry( platform = MockPlatform(async_setup_entry=async_setup_entry) config_entry = MockConfigEntry(title=config_entry_title, entry_id="super-mock-id") + config_entry.add_to_hass(hass) entity_platform = MockEntityPlatform( hass, platform_name=config_entry.domain, platform=platform )