From 769923e8dd5880c7d3a82793b40bb4ae039cb95b Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Tue, 13 Apr 2021 08:18:51 -0400 Subject: [PATCH] Raise exception for invalid call to DeviceRegistry.async_get_or_create (#49038) * Raise exception instead of returning None for DeviceRegistry.async_get_or_create * fix entity_platform logic --- homeassistant/exceptions.py | 15 +++++++++++++++ homeassistant/helpers/device_registry.py | 12 +++++++++--- homeassistant/helpers/entity_platform.py | 12 +++++++++--- tests/helpers/test_device_registry.py | 20 ++++++++++++-------- 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index fba00e094cd..a081cfe3cc2 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -183,3 +183,18 @@ class MaxLengthExceeded(HomeAssistantError): self.value = value self.property_name = property_name self.max_length = max_length + + +class RequiredParameterMissing(HomeAssistantError): + """Raised when a required parameter is missing from a function call.""" + + def __init__(self, parameter_names: list[str]) -> None: + """Initialize error.""" + super().__init__( + self, + ( + "Call must include at least one of the following parameters: " + f"{', '.join(parameter_names)}" + ), + ) + self.parameter_names = parameter_names diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index e0e5130a94f..80c54ed296f 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -10,6 +10,7 @@ import attr from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.exceptions import RequiredParameterMissing from homeassistant.loader import bind_hass import homeassistant.util.uuid as uuid_util @@ -259,10 +260,10 @@ class DeviceRegistry: # To disable a device if it gets created disabled_by: str | None | UndefinedType = UNDEFINED, suggested_area: str | None | UndefinedType = UNDEFINED, - ) -> DeviceEntry | None: + ) -> DeviceEntry: """Get device. Create if it doesn't exist.""" if not identifiers and not connections: - return None + raise RequiredParameterMissing(["identifiers", "connections"]) if identifiers is None: identifiers = set() @@ -300,7 +301,7 @@ class DeviceRegistry: else: via_device_id = UNDEFINED - return self._async_update_device( + device = self._async_update_device( device.id, add_config_entry_id=config_entry_id, via_device_id=via_device_id, @@ -315,6 +316,11 @@ class DeviceRegistry: suggested_area=suggested_area, ) + # This is safe because _async_update_device will always return a device + # in this use case. + assert device + return device + @callback def async_update_device( self, diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 490a5a2298c..25996c81d9d 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -24,7 +24,11 @@ from homeassistant.core import ( split_entity_id, valid_entity_id, ) -from homeassistant.exceptions import HomeAssistantError, PlatformNotReady +from homeassistant.exceptions import ( + HomeAssistantError, + PlatformNotReady, + RequiredParameterMissing, +) from homeassistant.helpers import ( config_validation as cv, device_registry as dev_reg, @@ -434,9 +438,11 @@ class EntityPlatform: if key in device_info: processed_dev_info[key] = device_info[key] - device = device_registry.async_get_or_create(**processed_dev_info) - if device: + try: + device = device_registry.async_get_or_create(**processed_dev_info) device_id = device.id + except RequiredParameterMissing: + pass disabled_by: str | None = None if not entity.entity_registry_enabled_default: diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index c5328000269..1a768662fc7 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -6,6 +6,7 @@ import pytest from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, callback +from homeassistant.exceptions import RequiredParameterMissing from homeassistant.helpers import device_registry, entity_registry from tests.common import ( @@ -114,18 +115,21 @@ async def test_requirement_for_identifier_or_connection(registry): manufacturer="manufacturer", model="model", ) - entry3 = registry.async_get_or_create( - config_entry_id="1234", - connections=set(), - identifiers=set(), - manufacturer="manufacturer", - model="model", - ) assert len(registry.devices) == 2 assert entry assert entry2 - assert entry3 is None + + with pytest.raises(RequiredParameterMissing) as exc_info: + registry.async_get_or_create( + config_entry_id="1234", + connections=set(), + identifiers=set(), + manufacturer="manufacturer", + model="model", + ) + + assert exc_info.value.parameter_names == ["identifiers", "connections"] async def test_multiple_config_entries(registry):