Don't allow creating device if config entry does not exist (#98157)

* Don't allow creating device if config entry does not exist

* Fix test

* Update test
This commit is contained in:
Erik Montnemery 2023-08-11 04:09:13 +02:00 committed by GitHub
parent 045c327928
commit 2e1a5ddf2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 9 deletions

View File

@ -158,7 +158,7 @@ class DeviceInfoError(HomeAssistantError):
def _validate_device_info(
config_entry: ConfigEntry | None,
config_entry: ConfigEntry,
device_info: DeviceInfo,
) -> str:
"""Process a device info."""
@ -167,7 +167,7 @@ def _validate_device_info(
# If no keys or not enough info to match up, abort
if not device_info.get("connections") and not device_info.get("identifiers"):
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
config_entry.domain,
device_info,
"device info must include at least one of identifiers or connections",
)
@ -182,7 +182,7 @@ def _validate_device_info(
if device_info_type is None:
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
config_entry.domain,
device_info,
(
"device info needs to either describe a device, "
@ -527,6 +527,10 @@ class DeviceRegistry:
device_info[key] = val # type: ignore[literal-required]
config_entry = self.hass.config_entries.async_get_entry(config_entry_id)
if config_entry is None:
raise HomeAssistantError(
f"Can't link device to unknown config entry {config_entry_id}"
)
device_info_type = _validate_device_info(config_entry, device_info)
if identifiers is None or identifiers is UNDEFINED:
@ -550,11 +554,7 @@ class DeviceRegistry:
)
self.devices[device.id] = device
# If creating a new device, default to the config entry name
if (
device_info_type == "primary"
and (not name or name is UNDEFINED)
and config_entry
):
if device_info_type == "primary" and (not name or name is UNDEFINED):
name = config_entry.title
if default_manufacturer is not UNDEFINED and device.manufacturer is None:

View File

@ -66,6 +66,7 @@ from .test_common import (
)
from tests.common import (
MockConfigEntry,
async_fire_mqtt_message,
async_fire_time_changed,
mock_restore_cache_with_extra_data,
@ -1123,9 +1124,11 @@ async def test_entity_device_info_with_hub(
) -> None:
"""Test MQTT sensor device registry integration."""
await mqtt_mock_entry()
other_config_entry = MockConfigEntry()
other_config_entry.add_to_hass(hass)
registry = dr.async_get(hass)
hub = registry.async_get_or_create(
config_entry_id="123",
config_entry_id=other_config_entry.entry_id,
connections=set(),
identifiers={("mqtt", "hub-id")},
manufacturer="manufacturer",