Allow changing entity ID (#15637)

* Allow changing entity ID

* Add support to websocket command

* Address comments

* Error handling
This commit is contained in:
Paulus Schoutsen 2018-07-24 14:12:53 +02:00 committed by GitHub
parent fbeaa57604
commit d9cf8fcfe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 206 additions and 22 deletions

View File

@ -20,6 +20,7 @@ SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('entity_id'): cv.entity_id, vol.Required('entity_id'): cv.entity_id,
# If passed in, we update value. Passing None will remove old value. # If passed in, we update value. Passing None will remove old value.
vol.Optional('name'): vol.Any(str, None), vol.Optional('name'): vol.Any(str, None),
vol.Optional('new_entity_id'): str,
}) })
@ -74,13 +75,28 @@ def websocket_update_entity(hass, connection, msg):
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found')) msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
return return
entry = registry.async_update_entity( changes = {}
msg['entity_id'], name=msg['name'])
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
hass.async_add_job(update_entity()) if 'name' in msg:
changes['name'] = msg['name']
if 'new_entity_id' in msg:
changes['new_entity_id'] = msg['new_entity_id']
try:
if changes:
entry = registry.async_update_entity(
msg['entity_id'], **changes)
except ValueError as err:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_info', str(err)
))
else:
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
hass.async_create_task(update_entity())
@callback @callback

View File

@ -82,6 +82,9 @@ class Entity:
# Name in the entity registry # Name in the entity registry
registry_name = None registry_name = None
# Hold list for functions to call on remove.
_on_remove = None
@property @property
def should_poll(self) -> bool: def should_poll(self) -> bool:
"""Return True if entity has to be polled for state. """Return True if entity has to be polled for state.
@ -324,8 +327,19 @@ class Entity:
if self.parallel_updates: if self.parallel_updates:
self.parallel_updates.release() self.parallel_updates.release()
@callback
def async_on_remove(self, func):
"""Add a function to call when entity removed."""
if self._on_remove is None:
self._on_remove = []
self._on_remove.append(func)
async def async_remove(self): async def async_remove(self):
"""Remove entity from Home Assistant.""" """Remove entity from Home Assistant."""
if self._on_remove is not None:
while self._on_remove:
self._on_remove.pop()()
if self.platform is not None: if self.platform is not None:
await self.platform.async_remove_entity(self.entity_id) await self.platform.async_remove_entity(self.entity_id)
else: else:
@ -335,7 +349,17 @@ class Entity:
def async_registry_updated(self, old, new): def async_registry_updated(self, old, new):
"""Called when the entity registry has been updated.""" """Called when the entity registry has been updated."""
self.registry_name = new.name self.registry_name = new.name
self.async_schedule_update_ha_state()
if new.entity_id == self.entity_id:
self.async_schedule_update_ha_state()
return
async def readd():
"""Remove and add entity again."""
await self.async_remove()
await self.platform.async_add_entities([self])
self.hass.async_create_task(readd())
def __eq__(self, other): def __eq__(self, other):
"""Return the comparison.""" """Return the comparison."""

View File

@ -283,7 +283,7 @@ class EntityPlatform:
entity.entity_id = entry.entity_id entity.entity_id = entry.entity_id
entity.registry_name = entry.name entity.registry_name = entry.name
entry.add_update_listener(entity) entity.async_on_remove(entry.add_update_listener(entity))
# We won't generate an entity ID if the platform has already set one # We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID # We will however make sure that platform cannot pick a registered ID

View File

@ -19,10 +19,10 @@ import weakref
import attr import attr
from ..core import callback, split_entity_id from homeassistant.core import callback, split_entity_id, valid_entity_id
from ..loader import bind_hass from homeassistant.loader import bind_hass
from ..util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
from ..util.yaml import load_yaml, save_yaml from homeassistant.util.yaml import load_yaml, save_yaml
PATH_REGISTRY = 'entity_registry.yaml' PATH_REGISTRY = 'entity_registry.yaml'
DATA_REGISTRY = 'entity_registry' DATA_REGISTRY = 'entity_registry'
@ -63,8 +63,13 @@ class RegistryEntry:
"""Listen for when entry is updated. """Listen for when entry is updated.
Listener: Callback function(old_entry, new_entry) Listener: Callback function(old_entry, new_entry)
Returns function to unlisten.
""" """
self.update_listeners.append(weakref.ref(listener)) weak_listener = weakref.ref(listener)
self.update_listeners.append(weak_listener)
return lambda: self.update_listeners.remove(weak_listener)
class EntityRegistry: class EntityRegistry:
@ -133,13 +138,18 @@ class EntityRegistry:
return entity return entity
@callback @callback
def async_update_entity(self, entity_id, *, name=_UNDEF): def async_update_entity(self, entity_id, *, name=_UNDEF,
new_entity_id=_UNDEF):
"""Update properties of an entity.""" """Update properties of an entity."""
return self._async_update_entity(entity_id, name=name) return self._async_update_entity(
entity_id,
name=name,
new_entity_id=new_entity_id
)
@callback @callback
def _async_update_entity(self, entity_id, *, name=_UNDEF, def _async_update_entity(self, entity_id, *, name=_UNDEF,
config_entry_id=_UNDEF): config_entry_id=_UNDEF, new_entity_id=_UNDEF):
"""Private facing update properties method.""" """Private facing update properties method."""
old = self.entities[entity_id] old = self.entities[entity_id]
@ -152,6 +162,20 @@ class EntityRegistry:
config_entry_id != old.config_entry_id): config_entry_id != old.config_entry_id):
changes['config_entry_id'] = config_entry_id changes['config_entry_id'] = config_entry_id
if new_entity_id is not _UNDEF and new_entity_id != old.entity_id:
if self.async_is_registered(new_entity_id):
raise ValueError('Entity is already registered')
if not valid_entity_id(new_entity_id):
raise ValueError('Invalid entity ID')
if (split_entity_id(new_entity_id)[0] !=
split_entity_id(entity_id)[0]):
raise ValueError('New entity ID should be same domain')
self.entities.pop(entity_id)
entity_id = changes['entity_id'] = new_entity_id
if not changes: if not changes:
return old return old

View File

@ -54,8 +54,8 @@ async def test_get_entity(hass, client):
} }
async def test_update_entity(hass, client): async def test_update_entity_name(hass, client):
"""Test get entry.""" """Test updating entity name."""
mock_registry(hass, { mock_registry(hass, {
'test_domain.world': RegistryEntry( 'test_domain.world': RegistryEntry(
entity_id='test_domain.world', entity_id='test_domain.world',
@ -92,7 +92,7 @@ async def test_update_entity(hass, client):
async def test_update_entity_no_changes(hass, client): async def test_update_entity_no_changes(hass, client):
"""Test get entry.""" """Test update entity with no changes."""
mock_registry(hass, { mock_registry(hass, {
'test_domain.world': RegistryEntry( 'test_domain.world': RegistryEntry(
entity_id='test_domain.world', entity_id='test_domain.world',
@ -129,7 +129,7 @@ async def test_update_entity_no_changes(hass, client):
async def test_get_nonexisting_entity(client): async def test_get_nonexisting_entity(client):
"""Test get entry.""" """Test get entry with nonexisting entity."""
await client.send_json({ await client.send_json({
'id': 6, 'id': 6,
'type': 'config/entity_registry/get', 'type': 'config/entity_registry/get',
@ -141,7 +141,7 @@ async def test_get_nonexisting_entity(client):
async def test_update_nonexisting_entity(client): async def test_update_nonexisting_entity(client):
"""Test get entry.""" """Test update a nonexisting entity."""
await client.send_json({ await client.send_json({
'id': 6, 'id': 6,
'type': 'config/entity_registry/update', 'type': 'config/entity_registry/update',
@ -151,3 +151,37 @@ async def test_update_nonexisting_entity(client):
msg = await client.receive_json() msg = await client.receive_json()
assert not msg['success'] assert not msg['success']
async def test_update_entity_id(hass, client):
"""Test update entity id."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
assert hass.states.get('test_domain.world') is not None
await client.send_json({
'id': 6,
'type': 'config/entity_registry/update',
'entity_id': 'test_domain.world',
'new_entity_id': 'test_domain.planet',
})
msg = await client.receive_json()
assert msg['result'] == {
'entity_id': 'test_domain.planet',
'name': None
}
assert hass.states.get('test_domain.world') is None
assert hass.states.get('test_domain.planet') is not None

View File

@ -400,3 +400,15 @@ def test_async_remove_no_platform(hass):
assert len(hass.states.async_entity_ids()) == 1 assert len(hass.states.async_entity_ids()) == 1
yield from ent.async_remove() yield from ent.async_remove()
assert len(hass.states.async_entity_ids()) == 0 assert len(hass.states.async_entity_ids()) == 0
async def test_async_remove_runs_callbacks(hass):
"""Test async_remove method when no platform set."""
result = []
ent = entity.Entity()
ent.hass = hass
ent.entity_id = 'test.test'
ent.async_on_remove(lambda: result.append(1))
await ent.async_remove()
assert len(result) == 1

View File

@ -5,6 +5,8 @@ import unittest
from unittest.mock import patch, Mock, MagicMock from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta from datetime import timedelta
import pytest
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
import homeassistant.loader as loader import homeassistant.loader as loader
from homeassistant.helpers.entity import generate_entity_id from homeassistant.helpers.entity import generate_entity_id
@ -487,7 +489,7 @@ def test_registry_respect_entity_disabled(hass):
assert hass.states.async_entity_ids() == [] assert hass.states.async_entity_ids() == []
async def test_entity_registry_updates(hass): async def test_entity_registry_updates_name(hass):
"""Test that updates on the entity registry update platform entities.""" """Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, { registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry( 'test_domain.world': entity_registry.RegistryEntry(
@ -602,3 +604,75 @@ def test_not_fails_with_adding_empty_entities_(hass):
yield from component.async_add_entities([]) yield from component.async_add_entities([])
assert len(hass.states.async_entity_ids()) == 0 assert len(hass.states.async_entity_ids()) == 0
async def test_entity_registry_updates_entity_id(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='Some name'
)
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'Some name'
registry.async_update_entity('test_domain.world',
new_entity_id='test_domain.planet')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert hass.states.get('test_domain.world') is None
assert hass.states.get('test_domain.planet') is not None
async def test_entity_registry_updates_invalid_entity_id(hass):
"""Test that we can't update to an invalid entity id."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
entity_id='test_domain.world',
unique_id='1234',
# Using component.async_add_entities is equal to platform "domain"
platform='test_platform',
name='Some name'
),
'test_domain.existing': entity_registry.RegistryEntry(
entity_id='test_domain.existing',
unique_id='5678',
platform='test_platform',
),
})
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id='1234')
await platform.async_add_entities([entity])
state = hass.states.get('test_domain.world')
assert state is not None
assert state.name == 'Some name'
with pytest.raises(ValueError):
registry.async_update_entity('test_domain.world',
new_entity_id='test_domain.existing')
with pytest.raises(ValueError):
registry.async_update_entity('test_domain.world',
new_entity_id='invalid_entity_id')
with pytest.raises(ValueError):
registry.async_update_entity('test_domain.world',
new_entity_id='diff_domain.world')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert hass.states.get('test_domain.world') is not None
assert hass.states.get('invalid_entity_id') is None
assert hass.states.get('diff_domain.world') is None