mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Store capabilities and supported features in entity registry, restore registered entities on startup (#30094)
* Store capabilities and supported features in entity registry * Restore states at startup * Restore non-disabled entities on HA start * Fix test * Pass device class from entity platform * Clean up restored entities from state machine * Fix Z-Wave test?
This commit is contained in:
parent
2c1a7a54cd
commit
bb14a083f0
@ -311,7 +311,9 @@ class Entity(ABC):
|
||||
|
||||
start = timer()
|
||||
|
||||
attr = self.capability_attributes or {}
|
||||
attr = self.capability_attributes
|
||||
attr = dict(attr) if attr else {}
|
||||
|
||||
if not self.available:
|
||||
state = STATE_UNAVAILABLE
|
||||
else:
|
||||
|
@ -347,6 +347,9 @@ class EntityPlatform:
|
||||
device_id=device_id,
|
||||
known_object_ids=self.entities.keys(),
|
||||
disabled_by=disabled_by,
|
||||
capabilities=entity.capability_attributes,
|
||||
supported_features=entity.supported_features,
|
||||
device_class=entity.device_class,
|
||||
)
|
||||
|
||||
entity.registry_entry = entry
|
||||
@ -387,10 +390,16 @@ class EntityPlatform:
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if not valid_entity_id(entity.entity_id):
|
||||
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
|
||||
if (
|
||||
entity.entity_id in self.entities
|
||||
or entity.entity_id in self.hass.states.async_entity_ids(self.domain)
|
||||
):
|
||||
|
||||
already_exists = entity.entity_id in self.entities
|
||||
|
||||
if not already_exists:
|
||||
existing = self.hass.states.get(entity.entity_id)
|
||||
|
||||
if existing and not existing.attributes.get("restored"):
|
||||
already_exists = True
|
||||
|
||||
if already_exists:
|
||||
msg = f"Entity id already exists: {entity.entity_id}"
|
||||
if entity.unique_id is not None:
|
||||
msg += ". Platform {} does not generate unique IDs".format(
|
||||
|
@ -15,6 +15,12 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
|
||||
|
||||
import attr
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
ATTR_SUPPORTED_FEATURES,
|
||||
EVENT_HOMEASSISTANT_START,
|
||||
STATE_UNAVAILABLE,
|
||||
)
|
||||
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
|
||||
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -39,6 +45,8 @@ DISABLED_HASS = "hass"
|
||||
DISABLED_USER = "user"
|
||||
DISABLED_INTEGRATION = "integration"
|
||||
|
||||
ATTR_RESTORED = "restored"
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "core.entity_registry"
|
||||
|
||||
@ -66,6 +74,9 @@ class RegistryEntry:
|
||||
)
|
||||
),
|
||||
)
|
||||
capabilities: Optional[Dict[str, Any]] = attr.ib(default=None)
|
||||
supported_features: int = attr.ib(default=0)
|
||||
device_class: Optional[str] = attr.ib(default=None)
|
||||
domain = attr.ib(type=str, init=False, repr=False)
|
||||
|
||||
@domain.default
|
||||
@ -142,11 +153,17 @@ class EntityRegistry:
|
||||
platform: str,
|
||||
unique_id: str,
|
||||
*,
|
||||
# To influence entity ID generation
|
||||
suggested_object_id: Optional[str] = None,
|
||||
known_object_ids: Optional[Iterable[str]] = None,
|
||||
# To disable an entity if it gets created
|
||||
disabled_by: Optional[str] = None,
|
||||
# Data that we want entry to have
|
||||
config_entry: Optional["ConfigEntry"] = None,
|
||||
device_id: Optional[str] = None,
|
||||
known_object_ids: Optional[Iterable[str]] = None,
|
||||
disabled_by: Optional[str] = None,
|
||||
capabilities: Optional[Dict[str, Any]] = None,
|
||||
supported_features: Optional[int] = None,
|
||||
device_class: Optional[str] = None,
|
||||
) -> RegistryEntry:
|
||||
"""Get entity. Create if it doesn't exist."""
|
||||
config_entry_id = None
|
||||
@ -160,6 +177,9 @@ class EntityRegistry:
|
||||
entity_id,
|
||||
config_entry_id=config_entry_id or _UNDEF,
|
||||
device_id=device_id or _UNDEF,
|
||||
capabilities=capabilities or _UNDEF,
|
||||
supported_features=supported_features or _UNDEF,
|
||||
device_class=device_class or _UNDEF,
|
||||
# When we changed our slugify algorithm, we invalidated some
|
||||
# stored entity IDs with either a __ or ending in _.
|
||||
# Fix introduced in 0.86 (Jan 23, 2019). Next line can be
|
||||
@ -187,6 +207,9 @@ class EntityRegistry:
|
||||
unique_id=unique_id,
|
||||
platform=platform,
|
||||
disabled_by=disabled_by,
|
||||
capabilities=capabilities,
|
||||
supported_features=supported_features or 0,
|
||||
device_class=device_class,
|
||||
)
|
||||
self.entities[entity_id] = entity
|
||||
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
||||
@ -253,6 +276,9 @@ class EntityRegistry:
|
||||
device_id=_UNDEF,
|
||||
new_unique_id=_UNDEF,
|
||||
disabled_by=_UNDEF,
|
||||
capabilities=_UNDEF,
|
||||
supported_features=_UNDEF,
|
||||
device_class=_UNDEF,
|
||||
):
|
||||
"""Private facing update properties method."""
|
||||
old = self.entities[entity_id]
|
||||
@ -264,6 +290,9 @@ class EntityRegistry:
|
||||
("config_entry_id", config_entry_id),
|
||||
("device_id", device_id),
|
||||
("disabled_by", disabled_by),
|
||||
("capabilities", capabilities),
|
||||
("supported_features", supported_features),
|
||||
("device_class", device_class),
|
||||
):
|
||||
if value is not _UNDEF and value != getattr(old, attr_name):
|
||||
changes[attr_name] = value
|
||||
@ -318,6 +347,8 @@ class EntityRegistry:
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the entity registry."""
|
||||
async_setup_entity_restore(self.hass, self)
|
||||
|
||||
data = await self.hass.helpers.storage.async_migrator(
|
||||
self.hass.config.path(PATH_REGISTRY),
|
||||
self._store,
|
||||
@ -336,6 +367,9 @@ class EntityRegistry:
|
||||
platform=entity["platform"],
|
||||
name=entity.get("name"),
|
||||
disabled_by=entity.get("disabled_by"),
|
||||
capabilities=entity.get("capabilities") or {},
|
||||
supported_features=entity.get("supported_features", 0),
|
||||
device_class=entity.get("device_class"),
|
||||
)
|
||||
|
||||
self.entities = entities
|
||||
@ -359,6 +393,9 @@ class EntityRegistry:
|
||||
"platform": entry.platform,
|
||||
"name": entry.name,
|
||||
"disabled_by": entry.disabled_by,
|
||||
"capabilities": entry.capabilities,
|
||||
"supported_features": entry.supported_features,
|
||||
"device_class": entry.device_class,
|
||||
}
|
||||
for entry in self.entities.values()
|
||||
]
|
||||
@ -416,3 +453,53 @@ async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, A
|
||||
{"entity_id": entity_id, **info} for entity_id, info in entities.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_entity_restore(
|
||||
hass: HomeAssistantType, registry: EntityRegistry
|
||||
) -> None:
|
||||
"""Set up the entity restore mechanism."""
|
||||
|
||||
@callback
|
||||
def cleanup_restored_states(event: Event) -> None:
|
||||
"""Clean up restored states."""
|
||||
if event.data["action"] != "remove":
|
||||
return
|
||||
|
||||
state = hass.states.get(event.data["entity_id"])
|
||||
|
||||
if state is None or not state.attributes.get(ATTR_RESTORED):
|
||||
return
|
||||
|
||||
hass.states.async_remove(event.data["entity_id"])
|
||||
|
||||
hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states)
|
||||
|
||||
if hass.is_running:
|
||||
return
|
||||
|
||||
@callback
|
||||
def _write_unavailable_states(_: Event) -> None:
|
||||
"""Make sure state machine contains entry for each registered entity."""
|
||||
states = hass.states
|
||||
existing = set(states.async_entity_ids())
|
||||
|
||||
for entry in registry.entities.values():
|
||||
if entry.entity_id in existing or entry.disabled:
|
||||
continue
|
||||
|
||||
attrs: Dict[str, Any] = {ATTR_RESTORED: True}
|
||||
|
||||
if entry.capabilities:
|
||||
attrs.update(entry.capabilities)
|
||||
|
||||
if entry.supported_features:
|
||||
attrs[ATTR_SUPPORTED_FEATURES] = entry.supported_features
|
||||
|
||||
if entry.device_class:
|
||||
attrs[ATTR_DEVICE_CLASS] = entry.device_class
|
||||
|
||||
states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs)
|
||||
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states)
|
||||
|
@ -906,6 +906,11 @@ class MockEntity(entity.Entity):
|
||||
"""Return the unique ID of the entity."""
|
||||
return self._handle("unique_id")
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the state of the entity."""
|
||||
return self._handle("state")
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
"""Return True if entity is available."""
|
||||
@ -916,6 +921,21 @@ class MockEntity(entity.Entity):
|
||||
"""Info how it links to a device."""
|
||||
return self._handle("device_info")
|
||||
|
||||
@property
|
||||
def device_class(self):
|
||||
"""Info how device should be classified."""
|
||||
return self._handle("device_class")
|
||||
|
||||
@property
|
||||
def capability_attributes(self):
|
||||
"""Info about capabilities."""
|
||||
return self._handle("capability_attributes")
|
||||
|
||||
@property
|
||||
def supported_features(self):
|
||||
"""Info about supported features."""
|
||||
return self._handle("supported_features")
|
||||
|
||||
@property
|
||||
def entity_registry_enabled_default(self):
|
||||
"""Return if the entity should be enabled when first added to the entity registry."""
|
||||
|
@ -130,6 +130,7 @@ async def test_auto_heal_midnight(hass, mock_openzwave):
|
||||
time = utc.localize(datetime(2017, 5, 6, 0, 0, 0))
|
||||
async_fire_time_changed(hass, time)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert network.heal.called
|
||||
assert len(network.heal.mock_calls) == 1
|
||||
|
||||
|
@ -8,7 +8,6 @@ from unittest.mock import Mock, patch
|
||||
import asynctest
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import group
|
||||
from homeassistant.const import ENTITY_MATCH_ALL
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.exceptions import PlatformNotReady
|
||||
@ -285,15 +284,13 @@ async def test_extract_from_service_filter_out_non_existing_entities(hass):
|
||||
async def test_extract_from_service_no_group_expand(hass):
|
||||
"""Test not expanding a group."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
test_group = await group.Group.async_create_group(
|
||||
hass, "test_group", ["light.Ceiling", "light.Kitchen"]
|
||||
)
|
||||
await component.async_add_entities([test_group])
|
||||
await component.async_add_entities([MockEntity(entity_id="group.test_group")])
|
||||
|
||||
call = ha.ServiceCall("test", "service", {"entity_id": ["group.test_group"]})
|
||||
|
||||
extracted = await component.async_extract_from_service(call, expand_group=False)
|
||||
assert extracted == [test_group]
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0].entity_id == "group.test_group"
|
||||
|
||||
|
||||
async def test_setup_dependencies_platform(hass):
|
||||
|
@ -793,3 +793,44 @@ async def test_entity_disabled_by_integration(hass):
|
||||
assert entry_default.disabled_by is None
|
||||
entry_disabled = registry.async_get_or_create(DOMAIN, DOMAIN, "disabled")
|
||||
assert entry_disabled.disabled_by == "integration"
|
||||
|
||||
|
||||
async def test_entity_info_added_to_entity_registry(hass):
|
||||
"""Test entity info is written to entity registry."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass, timedelta(seconds=20))
|
||||
|
||||
entity_default = MockEntity(
|
||||
unique_id="default",
|
||||
capability_attributes={"max": 100},
|
||||
supported_features=5,
|
||||
device_class="mock-device-class",
|
||||
)
|
||||
|
||||
await component.async_add_entities([entity_default])
|
||||
|
||||
registry = await hass.helpers.entity_registry.async_get_registry()
|
||||
|
||||
entry_default = registry.async_get_or_create(DOMAIN, DOMAIN, "default")
|
||||
print(entry_default)
|
||||
assert entry_default.capabilities == {"max": 100}
|
||||
assert entry_default.supported_features == 5
|
||||
assert entry_default.device_class == "mock-device-class"
|
||||
|
||||
|
||||
async def test_override_restored_entities(hass):
|
||||
"""Test that we allow overriding restored entities."""
|
||||
registry = mock_registry(hass)
|
||||
registry.async_get_or_create(
|
||||
"test_domain", "test_domain", "1234", suggested_object_id="world"
|
||||
)
|
||||
|
||||
hass.states.async_set("test_domain.world", "unavailable", {"restored": True})
|
||||
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||||
|
||||
await component.async_add_entities(
|
||||
[MockEntity(unique_id="1234", state="on", entity_id="test_domain.world")], True
|
||||
)
|
||||
|
||||
state = hass.states.get("test_domain.world")
|
||||
assert state.state == "on"
|
||||
|
@ -5,7 +5,8 @@ from unittest.mock import patch
|
||||
import asynctest
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import callback, valid_entity_id
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
||||
from homeassistant.core import CoreState, callback, valid_entity_id
|
||||
from homeassistant.helpers import entity_registry
|
||||
|
||||
from tests.common import MockConfigEntry, flush_store, mock_registry
|
||||
@ -57,6 +58,52 @@ def test_get_or_create_suggested_object_id(registry):
|
||||
assert entry.entity_id == "light.beer"
|
||||
|
||||
|
||||
def test_get_or_create_updates_data(registry):
|
||||
"""Test that we update data in get_or_create."""
|
||||
orig_config_entry = MockConfigEntry(domain="light")
|
||||
|
||||
orig_entry = registry.async_get_or_create(
|
||||
"light",
|
||||
"hue",
|
||||
"5678",
|
||||
config_entry=orig_config_entry,
|
||||
device_id="mock-dev-id",
|
||||
capabilities={"max": 100},
|
||||
supported_features=5,
|
||||
device_class="mock-device-class",
|
||||
disabled_by=entity_registry.DISABLED_HASS,
|
||||
)
|
||||
|
||||
assert orig_entry.config_entry_id == orig_config_entry.entry_id
|
||||
assert orig_entry.device_id == "mock-dev-id"
|
||||
assert orig_entry.capabilities == {"max": 100}
|
||||
assert orig_entry.supported_features == 5
|
||||
assert orig_entry.device_class == "mock-device-class"
|
||||
assert orig_entry.disabled_by == entity_registry.DISABLED_HASS
|
||||
|
||||
new_config_entry = MockConfigEntry(domain="light")
|
||||
|
||||
new_entry = registry.async_get_or_create(
|
||||
"light",
|
||||
"hue",
|
||||
"5678",
|
||||
config_entry=new_config_entry,
|
||||
device_id="new-mock-dev-id",
|
||||
capabilities={"new-max": 100},
|
||||
supported_features=10,
|
||||
device_class="new-mock-device-class",
|
||||
disabled_by=entity_registry.DISABLED_USER,
|
||||
)
|
||||
|
||||
assert new_entry.config_entry_id == new_config_entry.entry_id
|
||||
assert new_entry.device_id == "new-mock-dev-id"
|
||||
assert new_entry.capabilities == {"new-max": 100}
|
||||
assert new_entry.supported_features == 10
|
||||
assert new_entry.device_class == "new-mock-device-class"
|
||||
# Should not be updated
|
||||
assert new_entry.disabled_by == entity_registry.DISABLED_HASS
|
||||
|
||||
|
||||
def test_get_or_create_suggested_object_id_conflict_register(registry):
|
||||
"""Test that we don't generate an entity id that is already registered."""
|
||||
entry = registry.async_get_or_create(
|
||||
@ -91,7 +138,15 @@ async def test_loading_saving_data(hass, registry):
|
||||
|
||||
orig_entry1 = registry.async_get_or_create("light", "hue", "1234")
|
||||
orig_entry2 = registry.async_get_or_create(
|
||||
"light", "hue", "5678", config_entry=mock_config
|
||||
"light",
|
||||
"hue",
|
||||
"5678",
|
||||
device_id="mock-dev-id",
|
||||
config_entry=mock_config,
|
||||
capabilities={"max": 100},
|
||||
supported_features=5,
|
||||
device_class="mock-device-class",
|
||||
disabled_by=entity_registry.DISABLED_HASS,
|
||||
)
|
||||
|
||||
assert len(registry.entities) == 2
|
||||
@ -104,13 +159,17 @@ async def test_loading_saving_data(hass, registry):
|
||||
# Ensure same order
|
||||
assert list(registry.entities) == list(registry2.entities)
|
||||
new_entry1 = registry.async_get_or_create("light", "hue", "1234")
|
||||
new_entry2 = registry.async_get_or_create(
|
||||
"light", "hue", "5678", config_entry=mock_config
|
||||
)
|
||||
new_entry2 = registry.async_get_or_create("light", "hue", "5678")
|
||||
|
||||
assert orig_entry1 == new_entry1
|
||||
assert orig_entry2 == new_entry2
|
||||
|
||||
assert new_entry2.device_id == "mock-dev-id"
|
||||
assert new_entry2.disabled_by == entity_registry.DISABLED_HASS
|
||||
assert new_entry2.capabilities == {"max": 100}
|
||||
assert new_entry2.supported_features == 5
|
||||
assert new_entry2.device_class == "mock-device-class"
|
||||
|
||||
|
||||
def test_generate_entity_considers_registered_entities(registry):
|
||||
"""Test that we don't create entity id that are already registered."""
|
||||
@ -417,3 +476,62 @@ async def test_disabled_by_system_options(registry):
|
||||
"light", "hue", "BBBB", config_entry=mock_config, disabled_by="user"
|
||||
)
|
||||
assert entry2.disabled_by == "user"
|
||||
|
||||
|
||||
async def test_restore_states(hass):
|
||||
"""Test restoring states."""
|
||||
hass.state = CoreState.not_running
|
||||
|
||||
registry = await entity_registry.async_get_registry(hass)
|
||||
|
||||
registry.async_get_or_create(
|
||||
"light", "hue", "1234", suggested_object_id="simple",
|
||||
)
|
||||
# Should not be created
|
||||
registry.async_get_or_create(
|
||||
"light",
|
||||
"hue",
|
||||
"5678",
|
||||
suggested_object_id="disabled",
|
||||
disabled_by=entity_registry.DISABLED_HASS,
|
||||
)
|
||||
registry.async_get_or_create(
|
||||
"light",
|
||||
"hue",
|
||||
"9012",
|
||||
suggested_object_id="all_info_set",
|
||||
capabilities={"max": 100},
|
||||
supported_features=5,
|
||||
device_class="mock-device-class",
|
||||
)
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
simple = hass.states.get("light.simple")
|
||||
assert simple is not None
|
||||
assert simple.state == STATE_UNAVAILABLE
|
||||
assert simple.attributes == {"restored": True}
|
||||
|
||||
disabled = hass.states.get("light.disabled")
|
||||
assert disabled is None
|
||||
|
||||
all_info_set = hass.states.get("light.all_info_set")
|
||||
assert all_info_set is not None
|
||||
assert all_info_set.state == STATE_UNAVAILABLE
|
||||
assert all_info_set.attributes == {
|
||||
"max": 100,
|
||||
"supported_features": 5,
|
||||
"device_class": "mock-device-class",
|
||||
"restored": True,
|
||||
}
|
||||
|
||||
registry.async_remove("light.disabled")
|
||||
registry.async_remove("light.simple")
|
||||
registry.async_remove("light.all_info_set")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("light.simple") is None
|
||||
assert hass.states.get("light.disabled") is None
|
||||
assert hass.states.get("light.all_info_set") is None
|
||||
|
Loading…
x
Reference in New Issue
Block a user