Storage entity registry (#16018)

* Split out storage delayed write

* Update code using delayed save

* Fix tests

* Fix typing test

* Add callback decorator

* Migrate entity registry to storage helper

* Make double loading protection easier

* Lint

* Fix tests

* Ordered Dict
This commit is contained in:
Paulus Schoutsen 2018-08-18 13:34:33 +02:00 committed by Pascal Vizeli
parent ef193b0f64
commit 8ec550d6e0
7 changed files with 167 additions and 149 deletions

View File

@ -6,15 +6,10 @@ identified by their domain, platform and a unique id provided by that platform.
The Entity Registry will persist itself 10 seconds after a new entity is The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the registered. Registering a new entity while a timer is in progress resets the
timer. timer.
After initializing, call EntityRegistry.async_ensure_loaded to load the data
from disk.
""" """
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
import os
import weakref import weakref
import attr import attr
@ -22,7 +17,7 @@ import attr
from homeassistant.core import callback, split_entity_id, valid_entity_id from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.yaml import load_yaml, save_yaml from homeassistant.util.yaml import load_yaml
PATH_REGISTRY = 'entity_registry.yaml' PATH_REGISTRY = 'entity_registry.yaml'
DATA_REGISTRY = 'entity_registry' DATA_REGISTRY = 'entity_registry'
@ -32,6 +27,9 @@ _UNDEF = object()
DISABLED_HASS = 'hass' DISABLED_HASS = 'hass'
DISABLED_USER = 'user' DISABLED_USER = 'user'
STORAGE_VERSION = 1
STORAGE_KEY = 'core.entity_registry'
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class RegistryEntry: class RegistryEntry:
@ -79,8 +77,7 @@ class EntityRegistry:
"""Initialize the registry.""" """Initialize the registry."""
self.hass = hass self.hass = hass
self.entities = None self.entities = None
self._load_task = None self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._sched_save = None
@callback @callback
def async_is_registered(self, entity_id): def async_is_registered(self, entity_id):
@ -199,71 +196,72 @@ class EntityRegistry:
return new return new
async def async_ensure_loaded(self): async def async_load(self):
"""Load the registry from disk."""
if self.entities is not None:
return
if self._load_task is None:
self._load_task = self.hass.async_add_job(self._async_load)
await self._load_task
async def _async_load(self):
"""Load the entity registry.""" """Load the entity registry."""
path = self.hass.config.path(PATH_REGISTRY) data = await self.hass.helpers.storage.async_migrator(
self.hass.config.path(PATH_REGISTRY), self._store,
old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate
)
entities = OrderedDict() entities = OrderedDict()
if os.path.isfile(path): if data is not None:
data = await self.hass.async_add_job(load_yaml, path) for entity in data['entities']:
entities[entity['entity_id']] = RegistryEntry(
for entity_id, info in data.items(): entity_id=entity['entity_id'],
entities[entity_id] = RegistryEntry( config_entry_id=entity.get('config_entry_id'),
entity_id=entity_id, unique_id=entity['unique_id'],
config_entry_id=info.get('config_entry_id'), platform=entity['platform'],
unique_id=info['unique_id'], name=entity.get('name'),
platform=info['platform'], disabled_by=entity.get('disabled_by')
name=info.get('name'),
disabled_by=info.get('disabled_by')
) )
self.entities = entities self.entities = entities
self._load_task = None
@callback @callback
def async_schedule_save(self): def async_schedule_save(self):
"""Schedule saving the entity registry.""" """Schedule saving the entity registry."""
if self._sched_save is not None: self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
self._sched_save.cancel()
self._sched_save = self.hass.loop.call_later( @callback
SAVE_DELAY, self.hass.async_add_job, self._async_save def _data_to_save(self):
) """Data of entity registry to store in a file."""
data = {}
async def _async_save(self): data['entities'] = [
"""Save the entity registry to a file.""" {
self._sched_save = None 'entity_id': entry.entity_id,
data = OrderedDict()
for entry in self.entities.values():
data[entry.entity_id] = {
'config_entry_id': entry.config_entry_id, 'config_entry_id': entry.config_entry_id,
'unique_id': entry.unique_id, 'unique_id': entry.unique_id,
'platform': entry.platform, 'platform': entry.platform,
'name': entry.name, 'name': entry.name,
} } for entry in self.entities.values()
]
await self.hass.async_add_job( return data
save_yaml, self.hass.config.path(PATH_REGISTRY), data)
@bind_hass @bind_hass
async def async_get_registry(hass) -> EntityRegistry: async def async_get_registry(hass) -> EntityRegistry:
"""Return entity registry instance.""" """Return entity registry instance."""
registry = hass.data.get(DATA_REGISTRY) task = hass.data.get(DATA_REGISTRY)
if registry is None: if task is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass) async def _load_reg():
registry = EntityRegistry(hass)
await registry.async_load()
return registry
await registry.async_ensure_loaded() task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
return registry
return await task
async def _async_migrate(entities):
"""Migrate the YAML config file to storage helper format."""
return {
'entities': [
{'entity_id': entity_id, **info}
for entity_id, info in entities.items()
]
}

View File

@ -15,7 +15,9 @@ _LOGGER = logging.getLogger(__name__)
@bind_hass @bind_hass
async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None): async def async_migrator(hass, old_path, store, *,
old_conf_load_func=json.load_json,
old_conf_migrate_func=None):
"""Helper function to migrate old data to a store and then load data. """Helper function to migrate old data to a store and then load data.
async def old_conf_migrate_func(old_data) async def old_conf_migrate_func(old_data)
@ -25,7 +27,7 @@ async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None):
if not os.path.isfile(old_path): if not os.path.isfile(old_path):
return None return None
return json.load_json(old_path) return old_conf_load_func(old_path)
config = await hass.async_add_executor_job(load_old_config) config = await hass.async_add_executor_job(load_old_config)
@ -52,7 +54,7 @@ class Store:
self._data = None self._data = None
self._unsub_delay_listener = None self._unsub_delay_listener = None
self._unsub_stop_listener = None self._unsub_stop_listener = None
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock(loop=hass.loop)
self._load_task = None self._load_task = None
@property @property

View File

@ -307,7 +307,12 @@ def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry.""" """Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass) registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or {} registry.entities = mock_entries or {}
hass.data[entity_registry.DATA_REGISTRY] = registry
async def _get_reg():
return registry
hass.data[entity_registry.DATA_REGISTRY] = \
hass.loop.create_task(_get_reg())
return registry return registry

View File

@ -14,7 +14,7 @@ from homeassistant.components import light
from homeassistant.helpers.intent import IntentHandleError from homeassistant.helpers.intent import IntentHandleError
from tests.common import ( from tests.common import (
async_mock_service, mock_service, get_test_home_assistant) async_mock_service, mock_service, get_test_home_assistant, mock_storage)
class TestLight(unittest.TestCase): class TestLight(unittest.TestCase):
@ -333,10 +333,11 @@ class TestLight(unittest.TestCase):
"group.all_lights.default,.4,.6,99\n" "group.all_lights.default,.4,.6,99\n"
with mock.patch('os.path.isfile', side_effect=_mock_isfile): with mock.patch('os.path.isfile', side_effect=_mock_isfile):
with mock.patch('builtins.open', side_effect=_mock_open): with mock.patch('builtins.open', side_effect=_mock_open):
self.assertTrue(setup_component( with mock_storage():
self.hass, light.DOMAIN, self.assertTrue(setup_component(
{light.DOMAIN: {CONF_PLATFORM: 'test'}} self.hass, light.DOMAIN,
)) {light.DOMAIN: {CONF_PLATFORM: 'test'}}
))
dev, _, _ = platform.DEVICES dev, _, _ = platform.DEVICES
light.turn_on(self.hass, dev.entity_id) light.turn_on(self.hass, dev.entity_id)
@ -371,10 +372,11 @@ class TestLight(unittest.TestCase):
"light.ceiling_2.default,.6,.6,100\n" "light.ceiling_2.default,.6,.6,100\n"
with mock.patch('os.path.isfile', side_effect=_mock_isfile): with mock.patch('os.path.isfile', side_effect=_mock_isfile):
with mock.patch('builtins.open', side_effect=_mock_open): with mock.patch('builtins.open', side_effect=_mock_open):
self.assertTrue(setup_component( with mock_storage():
self.hass, light.DOMAIN, self.assertTrue(setup_component(
{light.DOMAIN: {CONF_PLATFORM: 'test'}} self.hass, light.DOMAIN,
)) {light.DOMAIN: {CONF_PLATFORM: 'test'}}
))
dev = next(filter(lambda x: x.entity_id == 'light.ceiling_2', dev = next(filter(lambda x: x.entity_id == 'light.ceiling_2',
platform.DEVICES)) platform.DEVICES))

View File

@ -5,13 +5,14 @@ from datetime import timedelta, datetime
from unittest.mock import patch from unittest.mock import patch
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.setup import setup_component from homeassistant.setup import setup_component, async_setup_component
import homeassistant.components.sensor as sensor import homeassistant.components.sensor as sensor
from homeassistant.const import EVENT_STATE_CHANGED, STATE_UNAVAILABLE from homeassistant.const import EVENT_STATE_CHANGED, STATE_UNAVAILABLE
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import mock_mqtt_component, fire_mqtt_message, \ from tests.common import mock_mqtt_component, fire_mqtt_message, \
assert_setup_component assert_setup_component, async_fire_mqtt_message, \
async_mock_mqtt_component
from tests.common import get_test_home_assistant, mock_component from tests.common import get_test_home_assistant, mock_component
@ -331,27 +332,6 @@ class TestSensorMQTT(unittest.TestCase):
state.attributes.get('val')) state.attributes.get('val'))
self.assertEqual('100', state.state) self.assertEqual('100', state.state)
def test_unique_id(self):
"""Test unique id option only creates one sensor per unique_id."""
assert setup_component(self.hass, sensor.DOMAIN, {
sensor.DOMAIN: [{
'platform': 'mqtt',
'name': 'Test 1',
'state_topic': 'test-topic',
'unique_id': 'TOTALLY_UNIQUE'
}, {
'platform': 'mqtt',
'name': 'Test 2',
'state_topic': 'test-topic',
'unique_id': 'TOTALLY_UNIQUE'
}]
})
fire_mqtt_message(self.hass, 'test-topic', 'payload')
self.hass.block_till_done()
assert len(self.hass.states.all()) == 1
def test_invalid_device_class(self): def test_invalid_device_class(self):
"""Test device_class option with invalid value.""" """Test device_class option with invalid value."""
with assert_setup_component(0): with assert_setup_component(0):
@ -384,3 +364,26 @@ class TestSensorMQTT(unittest.TestCase):
assert state.attributes['device_class'] == 'temperature' assert state.attributes['device_class'] == 'temperature'
state = self.hass.states.get('sensor.test_2') state = self.hass.states.get('sensor.test_2')
assert 'device_class' not in state.attributes assert 'device_class' not in state.attributes
async def test_unique_id(hass):
"""Test unique id option only creates one sensor per unique_id."""
await async_mock_mqtt_component(hass)
assert await async_setup_component(hass, sensor.DOMAIN, {
sensor.DOMAIN: [{
'platform': 'mqtt',
'name': 'Test 1',
'state_topic': 'test-topic',
'unique_id': 'TOTALLY_UNIQUE'
}, {
'platform': 'mqtt',
'name': 'Test 2',
'state_topic': 'test-topic',
'unique_id': 'TOTALLY_UNIQUE'
}]
})
async_fire_mqtt_message(hass, 'test-topic', 'payload')
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 1

View File

@ -1,6 +1,6 @@
"""Tests for the Entity Registry.""" """Tests for the Entity Registry."""
import asyncio import asyncio
from unittest.mock import patch, mock_open from unittest.mock import patch
import pytest import pytest
@ -61,29 +61,13 @@ def test_get_or_create_suggested_object_id_conflict_existing(hass, registry):
@asyncio.coroutine @asyncio.coroutine
def test_create_triggers_save(hass, registry): def test_create_triggers_save(hass, registry):
"""Test that registering entry triggers a save.""" """Test that registering entry triggers a save."""
with patch.object(hass.loop, 'call_later') as mock_call_later: with patch.object(registry, 'async_schedule_save') as mock_schedule_save:
registry.async_get_or_create('light', 'hue', '1234') registry.async_get_or_create('light', 'hue', '1234')
assert len(mock_call_later.mock_calls) == 1 assert len(mock_schedule_save.mock_calls) == 1
@asyncio.coroutine async def test_loading_saving_data(hass, registry):
def test_save_timer_reset_on_subsequent_save(hass, registry):
"""Test we reset the save timer on a new create."""
with patch.object(hass.loop, 'call_later') as mock_call_later:
registry.async_get_or_create('light', 'hue', '1234')
assert len(mock_call_later.mock_calls) == 1
with patch.object(hass.loop, 'call_later') as mock_call_later_2:
registry.async_get_or_create('light', 'hue', '5678')
assert len(mock_call_later().cancel.mock_calls) == 1
assert len(mock_call_later_2.mock_calls) == 1
@asyncio.coroutine
def test_loading_saving_data(hass, registry):
"""Test that we load/save data correctly.""" """Test that we load/save data correctly."""
orig_entry1 = registry.async_get_or_create('light', 'hue', '1234') orig_entry1 = registry.async_get_or_create('light', 'hue', '1234')
orig_entry2 = registry.async_get_or_create( orig_entry2 = registry.async_get_or_create(
@ -91,18 +75,11 @@ def test_loading_saving_data(hass, registry):
assert len(registry.entities) == 2 assert len(registry.entities) == 2
with patch(YAML__OPEN_PATH, mock_open(), create=True) as mock_write:
yield from registry._async_save()
# Mock open calls are: open file, context enter, write, context leave
written = mock_write.mock_calls[2][1][0]
# Now load written data in new registry # Now load written data in new registry
registry2 = entity_registry.EntityRegistry(hass) registry2 = entity_registry.EntityRegistry(hass)
registry2._store = registry._store
with patch('os.path.isfile', return_value=True), \ await registry2.async_load()
patch(YAML__OPEN_PATH, mock_open(read_data=written), create=True):
yield from registry2._async_load()
# Ensure same order # Ensure same order
assert list(registry.entities) == list(registry2.entities) assert list(registry.entities) == list(registry2.entities)
@ -139,32 +116,37 @@ def test_is_registered(registry):
assert not registry.async_is_registered('light.non_existing') assert not registry.async_is_registered('light.non_existing')
@asyncio.coroutine async def test_loading_extra_values(hass, hass_storage):
def test_loading_extra_values(hass):
"""Test we load extra data from the registry.""" """Test we load extra data from the registry."""
written = """ hass_storage[entity_registry.STORAGE_KEY] = {
test.named: 'version': entity_registry.STORAGE_VERSION,
platform: super_platform 'data': {
unique_id: with-name 'entities': [
name: registry override {
test.no_name: 'entity_id': 'test.named',
platform: super_platform 'platform': 'super_platform',
unique_id: without-name 'unique_id': 'with-name',
test.disabled_user: 'name': 'registry override',
platform: super_platform }, {
unique_id: disabled-user 'entity_id': 'test.no_name',
disabled_by: user 'platform': 'super_platform',
test.disabled_hass: 'unique_id': 'without-name',
platform: super_platform }, {
unique_id: disabled-hass 'entity_id': 'test.disabled_user',
disabled_by: hass 'platform': 'super_platform',
""" 'unique_id': 'disabled-user',
'disabled_by': 'user',
}, {
'entity_id': 'test.disabled_hass',
'platform': 'super_platform',
'unique_id': 'disabled-hass',
'disabled_by': 'hass',
}
]
}
}
registry = entity_registry.EntityRegistry(hass) registry = await entity_registry.async_get_registry(hass)
with patch('os.path.isfile', return_value=True), \
patch(YAML__OPEN_PATH, mock_open(read_data=written), create=True):
yield from registry._async_load()
entry_with_name = registry.async_get_or_create( entry_with_name = registry.async_get_or_create(
'test', 'super_platform', 'with-name') 'test', 'super_platform', 'with-name')
@ -202,3 +184,31 @@ async def test_updating_config_entry_id(registry):
'light', 'hue', '5678', config_entry_id='mock-id-2') 'light', 'hue', '5678', config_entry_id='mock-id-2')
assert entry.entity_id == entry2.entity_id assert entry.entity_id == entry2.entity_id
assert entry2.config_entry_id == 'mock-id-2' assert entry2.config_entry_id == 'mock-id-2'
async def test_migration(hass):
"""Test migration from old data to new."""
old_conf = {
'light.kitchen': {
'config_entry_id': 'test-config-id',
'unique_id': 'test-unique',
'platform': 'test-platform',
'name': 'Test Name',
'disabled_by': 'hass',
}
}
with patch('os.path.isfile', return_value=True), patch('os.remove'), \
patch('homeassistant.helpers.entity_registry.load_yaml',
return_value=old_conf):
registry = await entity_registry.async_get_registry(hass)
assert registry.async_is_registered('light.kitchen')
entry = registry.async_get_or_create(
domain='light',
platform='test-platform',
unique_id='test-unique',
config_entry_id='test-config-id',
)
assert entry.name == 'Test Name'
assert entry.disabled_by == 'hass'
assert entry.config_entry_id == 'test-config-id'

View File

@ -141,11 +141,10 @@ async def test_migrator_no_existing_config(hass, store, hass_storage):
async def test_migrator_existing_config(hass, store, hass_storage): async def test_migrator_existing_config(hass, store, hass_storage):
"""Test migrating existing config.""" """Test migrating existing config."""
with patch('os.path.isfile', return_value=True), \ with patch('os.path.isfile', return_value=True), \
patch('os.remove') as mock_remove, \ patch('os.remove') as mock_remove:
patch('homeassistant.util.json.load_json',
return_value={'old': 'config'}):
data = await storage.async_migrator( data = await storage.async_migrator(
hass, 'old-path', store) hass, 'old-path', store,
old_conf_load_func=lambda _: {'old': 'config'})
assert len(mock_remove.mock_calls) == 1 assert len(mock_remove.mock_calls) == 1
assert data == {'old': 'config'} assert data == {'old': 'config'}
@ -163,12 +162,11 @@ async def test_migrator_transforming_config(hass, store, hass_storage):
return {'new': old_config['old']} return {'new': old_config['old']}
with patch('os.path.isfile', return_value=True), \ with patch('os.path.isfile', return_value=True), \
patch('os.remove') as mock_remove, \ patch('os.remove') as mock_remove:
patch('homeassistant.util.json.load_json',
return_value={'old': 'config'}):
data = await storage.async_migrator( data = await storage.async_migrator(
hass, 'old-path', store, hass, 'old-path', store,
old_conf_migrate_func=old_conf_migrate_func) old_conf_migrate_func=old_conf_migrate_func,
old_conf_load_func=lambda _: {'old': 'config'})
assert len(mock_remove.mock_calls) == 1 assert len(mock_remove.mock_calls) == 1
assert data == {'new': 'config'} assert data == {'new': 'config'}