mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Raise exception for invalid call to DeviceRegistry.async_get_or_create (#49038)
* Raise exception instead of returning None for DeviceRegistry.async_get_or_create * fix entity_platform logic
This commit is contained in:
parent
2b79c91813
commit
769923e8dd
@ -183,3 +183,18 @@ class MaxLengthExceeded(HomeAssistantError):
|
|||||||
self.value = value
|
self.value = value
|
||||||
self.property_name = property_name
|
self.property_name = property_name
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredParameterMissing(HomeAssistantError):
|
||||||
|
"""Raised when a required parameter is missing from a function call."""
|
||||||
|
|
||||||
|
def __init__(self, parameter_names: list[str]) -> None:
|
||||||
|
"""Initialize error."""
|
||||||
|
super().__init__(
|
||||||
|
self,
|
||||||
|
(
|
||||||
|
"Call must include at least one of the following parameters: "
|
||||||
|
f"{', '.join(parameter_names)}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.parameter_names = parameter_names
|
||||||
|
@ -10,6 +10,7 @@ import attr
|
|||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
||||||
from homeassistant.core import Event, HomeAssistant, callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
|
from homeassistant.exceptions import RequiredParameterMissing
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
import homeassistant.util.uuid as uuid_util
|
import homeassistant.util.uuid as uuid_util
|
||||||
|
|
||||||
@ -259,10 +260,10 @@ class DeviceRegistry:
|
|||||||
# To disable a device if it gets created
|
# To disable a device if it gets created
|
||||||
disabled_by: str | None | UndefinedType = UNDEFINED,
|
disabled_by: str | None | UndefinedType = UNDEFINED,
|
||||||
suggested_area: str | None | UndefinedType = UNDEFINED,
|
suggested_area: str | None | UndefinedType = UNDEFINED,
|
||||||
) -> DeviceEntry | None:
|
) -> DeviceEntry:
|
||||||
"""Get device. Create if it doesn't exist."""
|
"""Get device. Create if it doesn't exist."""
|
||||||
if not identifiers and not connections:
|
if not identifiers and not connections:
|
||||||
return None
|
raise RequiredParameterMissing(["identifiers", "connections"])
|
||||||
|
|
||||||
if identifiers is None:
|
if identifiers is None:
|
||||||
identifiers = set()
|
identifiers = set()
|
||||||
@ -300,7 +301,7 @@ class DeviceRegistry:
|
|||||||
else:
|
else:
|
||||||
via_device_id = UNDEFINED
|
via_device_id = UNDEFINED
|
||||||
|
|
||||||
return self._async_update_device(
|
device = self._async_update_device(
|
||||||
device.id,
|
device.id,
|
||||||
add_config_entry_id=config_entry_id,
|
add_config_entry_id=config_entry_id,
|
||||||
via_device_id=via_device_id,
|
via_device_id=via_device_id,
|
||||||
@ -315,6 +316,11 @@ class DeviceRegistry:
|
|||||||
suggested_area=suggested_area,
|
suggested_area=suggested_area,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This is safe because _async_update_device will always return a device
|
||||||
|
# in this use case.
|
||||||
|
assert device
|
||||||
|
return device
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_update_device(
|
def async_update_device(
|
||||||
self,
|
self,
|
||||||
|
@ -24,7 +24,11 @@ from homeassistant.core import (
|
|||||||
split_entity_id,
|
split_entity_id,
|
||||||
valid_entity_id,
|
valid_entity_id,
|
||||||
)
|
)
|
||||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
from homeassistant.exceptions import (
|
||||||
|
HomeAssistantError,
|
||||||
|
PlatformNotReady,
|
||||||
|
RequiredParameterMissing,
|
||||||
|
)
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
config_validation as cv,
|
config_validation as cv,
|
||||||
device_registry as dev_reg,
|
device_registry as dev_reg,
|
||||||
@ -434,9 +438,11 @@ class EntityPlatform:
|
|||||||
if key in device_info:
|
if key in device_info:
|
||||||
processed_dev_info[key] = device_info[key]
|
processed_dev_info[key] = device_info[key]
|
||||||
|
|
||||||
|
try:
|
||||||
device = device_registry.async_get_or_create(**processed_dev_info)
|
device = device_registry.async_get_or_create(**processed_dev_info)
|
||||||
if device:
|
|
||||||
device_id = device.id
|
device_id = device.id
|
||||||
|
except RequiredParameterMissing:
|
||||||
|
pass
|
||||||
|
|
||||||
disabled_by: str | None = None
|
disabled_by: str | None = None
|
||||||
if not entity.entity_registry_enabled_default:
|
if not entity.entity_registry_enabled_default:
|
||||||
|
@ -6,6 +6,7 @@ import pytest
|
|||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
||||||
from homeassistant.core import CoreState, callback
|
from homeassistant.core import CoreState, callback
|
||||||
|
from homeassistant.exceptions import RequiredParameterMissing
|
||||||
from homeassistant.helpers import device_registry, entity_registry
|
from homeassistant.helpers import device_registry, entity_registry
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
@ -114,7 +115,13 @@ async def test_requirement_for_identifier_or_connection(registry):
|
|||||||
manufacturer="manufacturer",
|
manufacturer="manufacturer",
|
||||||
model="model",
|
model="model",
|
||||||
)
|
)
|
||||||
entry3 = registry.async_get_or_create(
|
|
||||||
|
assert len(registry.devices) == 2
|
||||||
|
assert entry
|
||||||
|
assert entry2
|
||||||
|
|
||||||
|
with pytest.raises(RequiredParameterMissing) as exc_info:
|
||||||
|
registry.async_get_or_create(
|
||||||
config_entry_id="1234",
|
config_entry_id="1234",
|
||||||
connections=set(),
|
connections=set(),
|
||||||
identifiers=set(),
|
identifiers=set(),
|
||||||
@ -122,10 +129,7 @@ async def test_requirement_for_identifier_or_connection(registry):
|
|||||||
model="model",
|
model="model",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(registry.devices) == 2
|
assert exc_info.value.parameter_names == ["identifiers", "connections"]
|
||||||
assert entry
|
|
||||||
assert entry2
|
|
||||||
assert entry3 is None
|
|
||||||
|
|
||||||
|
|
||||||
async def test_multiple_config_entries(registry):
|
async def test_multiple_config_entries(registry):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user