diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 2fa64ff8680..285d4cbd23a 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -6,15 +6,10 @@ identified by their domain, platform and a unique id provided by that platform. The Entity Registry will persist itself 10 seconds after a new entity is registered. Registering a new entity while a timer is in progress resets the timer. - -After initializing, call EntityRegistry.async_ensure_loaded to load the data -from disk. """ - from collections import OrderedDict from itertools import chain import logging -import os import weakref import attr @@ -22,7 +17,7 @@ import attr from homeassistant.core import callback, split_entity_id, valid_entity_id from homeassistant.loader import bind_hass from homeassistant.util import ensure_unique_string, slugify -from homeassistant.util.yaml import load_yaml, save_yaml +from homeassistant.util.yaml import load_yaml PATH_REGISTRY = 'entity_registry.yaml' DATA_REGISTRY = 'entity_registry' @@ -32,6 +27,9 @@ _UNDEF = object() DISABLED_HASS = 'hass' DISABLED_USER = 'user' +STORAGE_VERSION = 1 +STORAGE_KEY = 'core.entity_registry' + @attr.s(slots=True, frozen=True) class RegistryEntry: @@ -79,8 +77,7 @@ class EntityRegistry: """Initialize the registry.""" self.hass = hass self.entities = None - self._load_task = None - self._sched_save = None + self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) @callback def async_is_registered(self, entity_id): @@ -199,71 +196,72 @@ class EntityRegistry: return new - async def async_ensure_loaded(self): - """Load the registry from disk.""" - if self.entities is not None: - return - - if self._load_task is None: - self._load_task = self.hass.async_add_job(self._async_load) - - await self._load_task - - async def _async_load(self): + async def async_load(self): """Load the entity registry.""" - path = self.hass.config.path(PATH_REGISTRY) + data = await self.hass.helpers.storage.async_migrator( + self.hass.config.path(PATH_REGISTRY), self._store, + old_conf_load_func=load_yaml, + old_conf_migrate_func=_async_migrate + ) entities = OrderedDict() - if os.path.isfile(path): - data = await self.hass.async_add_job(load_yaml, path) - - for entity_id, info in data.items(): - entities[entity_id] = RegistryEntry( - entity_id=entity_id, - config_entry_id=info.get('config_entry_id'), - unique_id=info['unique_id'], - platform=info['platform'], - name=info.get('name'), - disabled_by=info.get('disabled_by') + if data is not None: + for entity in data['entities']: + entities[entity['entity_id']] = RegistryEntry( + entity_id=entity['entity_id'], + config_entry_id=entity.get('config_entry_id'), + unique_id=entity['unique_id'], + platform=entity['platform'], + name=entity.get('name'), + disabled_by=entity.get('disabled_by') ) self.entities = entities - self._load_task = None @callback def async_schedule_save(self): """Schedule saving the entity registry.""" - if self._sched_save is not None: - self._sched_save.cancel() + self._store.async_delay_save(self._data_to_save, SAVE_DELAY) - self._sched_save = self.hass.loop.call_later( - SAVE_DELAY, self.hass.async_add_job, self._async_save - ) + @callback + def _data_to_save(self): + """Data of entity registry to store in a file.""" + data = {} - async def _async_save(self): - """Save the entity registry to a file.""" - self._sched_save = None - data = OrderedDict() - - for entry in self.entities.values(): - data[entry.entity_id] = { + data['entities'] = [ + { + 'entity_id': entry.entity_id, 'config_entry_id': entry.config_entry_id, 'unique_id': entry.unique_id, 'platform': entry.platform, 'name': entry.name, - } + } for entry in self.entities.values() + ] - await self.hass.async_add_job( - save_yaml, self.hass.config.path(PATH_REGISTRY), data) + return data @bind_hass async def async_get_registry(hass) -> EntityRegistry: """Return entity registry instance.""" - registry = hass.data.get(DATA_REGISTRY) + task = hass.data.get(DATA_REGISTRY) - if registry is None: - registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass) + if task is None: + async def _load_reg(): + registry = EntityRegistry(hass) + await registry.async_load() + return registry - await registry.async_ensure_loaded() - return registry + task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg()) + + return await task + + +async def _async_migrate(entities): + """Migrate the YAML config file to storage helper format.""" + return { + 'entities': [ + {'entity_id': entity_id, **info} + for entity_id, info in entities.items() + ] + } diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index 47d182d9a7c..8931341f1a2 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -15,7 +15,9 @@ _LOGGER = logging.getLogger(__name__) @bind_hass -async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None): +async def async_migrator(hass, old_path, store, *, + old_conf_load_func=json.load_json, + old_conf_migrate_func=None): """Helper function to migrate old data to a store and then load data. async def old_conf_migrate_func(old_data) @@ -25,7 +27,7 @@ async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None): if not os.path.isfile(old_path): return None - return json.load_json(old_path) + return old_conf_load_func(old_path) config = await hass.async_add_executor_job(load_old_config) @@ -52,7 +54,7 @@ class Store: self._data = None self._unsub_delay_listener = None self._unsub_stop_listener = None - self._write_lock = asyncio.Lock() + self._write_lock = asyncio.Lock(loop=hass.loop) self._load_task = None @property diff --git a/tests/common.py b/tests/common.py index 81e4774ccd4..e7445751783 100644 --- a/tests/common.py +++ b/tests/common.py @@ -307,7 +307,12 @@ def mock_registry(hass, mock_entries=None): """Mock the Entity Registry.""" registry = entity_registry.EntityRegistry(hass) registry.entities = mock_entries or {} - hass.data[entity_registry.DATA_REGISTRY] = registry + + async def _get_reg(): + return registry + + hass.data[entity_registry.DATA_REGISTRY] = \ + hass.loop.create_task(_get_reg()) return registry diff --git a/tests/components/light/test_init.py b/tests/components/light/test_init.py index 4d779eef461..0f73c5a38c6 100644 --- a/tests/components/light/test_init.py +++ b/tests/components/light/test_init.py @@ -14,7 +14,7 @@ from homeassistant.components import light from homeassistant.helpers.intent import IntentHandleError from tests.common import ( - async_mock_service, mock_service, get_test_home_assistant) + async_mock_service, mock_service, get_test_home_assistant, mock_storage) class TestLight(unittest.TestCase): @@ -333,10 +333,11 @@ class TestLight(unittest.TestCase): "group.all_lights.default,.4,.6,99\n" with mock.patch('os.path.isfile', side_effect=_mock_isfile): with mock.patch('builtins.open', side_effect=_mock_open): - self.assertTrue(setup_component( - self.hass, light.DOMAIN, - {light.DOMAIN: {CONF_PLATFORM: 'test'}} - )) + with mock_storage(): + self.assertTrue(setup_component( + self.hass, light.DOMAIN, + {light.DOMAIN: {CONF_PLATFORM: 'test'}} + )) dev, _, _ = platform.DEVICES light.turn_on(self.hass, dev.entity_id) @@ -371,10 +372,11 @@ class TestLight(unittest.TestCase): "light.ceiling_2.default,.6,.6,100\n" with mock.patch('os.path.isfile', side_effect=_mock_isfile): with mock.patch('builtins.open', side_effect=_mock_open): - self.assertTrue(setup_component( - self.hass, light.DOMAIN, - {light.DOMAIN: {CONF_PLATFORM: 'test'}} - )) + with mock_storage(): + self.assertTrue(setup_component( + self.hass, light.DOMAIN, + {light.DOMAIN: {CONF_PLATFORM: 'test'}} + )) dev = next(filter(lambda x: x.entity_id == 'light.ceiling_2', platform.DEVICES)) diff --git a/tests/components/sensor/test_mqtt.py b/tests/components/sensor/test_mqtt.py index 2583f52b3d2..234afff3418 100644 --- a/tests/components/sensor/test_mqtt.py +++ b/tests/components/sensor/test_mqtt.py @@ -5,13 +5,14 @@ from datetime import timedelta, datetime from unittest.mock import patch import homeassistant.core as ha -from homeassistant.setup import setup_component +from homeassistant.setup import setup_component, async_setup_component import homeassistant.components.sensor as sensor from homeassistant.const import EVENT_STATE_CHANGED, STATE_UNAVAILABLE import homeassistant.util.dt as dt_util from tests.common import mock_mqtt_component, fire_mqtt_message, \ - assert_setup_component + assert_setup_component, async_fire_mqtt_message, \ + async_mock_mqtt_component from tests.common import get_test_home_assistant, mock_component @@ -331,27 +332,6 @@ class TestSensorMQTT(unittest.TestCase): state.attributes.get('val')) self.assertEqual('100', state.state) - def test_unique_id(self): - """Test unique id option only creates one sensor per unique_id.""" - assert setup_component(self.hass, sensor.DOMAIN, { - sensor.DOMAIN: [{ - 'platform': 'mqtt', - 'name': 'Test 1', - 'state_topic': 'test-topic', - 'unique_id': 'TOTALLY_UNIQUE' - }, { - 'platform': 'mqtt', - 'name': 'Test 2', - 'state_topic': 'test-topic', - 'unique_id': 'TOTALLY_UNIQUE' - }] - }) - - fire_mqtt_message(self.hass, 'test-topic', 'payload') - self.hass.block_till_done() - - assert len(self.hass.states.all()) == 1 - def test_invalid_device_class(self): """Test device_class option with invalid value.""" with assert_setup_component(0): @@ -384,3 +364,26 @@ class TestSensorMQTT(unittest.TestCase): assert state.attributes['device_class'] == 'temperature' state = self.hass.states.get('sensor.test_2') assert 'device_class' not in state.attributes + + +async def test_unique_id(hass): + """Test unique id option only creates one sensor per unique_id.""" + await async_mock_mqtt_component(hass) + assert await async_setup_component(hass, sensor.DOMAIN, { + sensor.DOMAIN: [{ + 'platform': 'mqtt', + 'name': 'Test 1', + 'state_topic': 'test-topic', + 'unique_id': 'TOTALLY_UNIQUE' + }, { + 'platform': 'mqtt', + 'name': 'Test 2', + 'state_topic': 'test-topic', + 'unique_id': 'TOTALLY_UNIQUE' + }] + }) + + async_fire_mqtt_message(hass, 'test-topic', 'payload') + await hass.async_block_till_done() + + assert len(hass.states.async_all()) == 1 diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 5a9efd5c041..d0c088a6f69 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -1,6 +1,6 @@ """Tests for the Entity Registry.""" import asyncio -from unittest.mock import patch, mock_open +from unittest.mock import patch import pytest @@ -61,29 +61,13 @@ def test_get_or_create_suggested_object_id_conflict_existing(hass, registry): @asyncio.coroutine def test_create_triggers_save(hass, registry): """Test that registering entry triggers a save.""" - with patch.object(hass.loop, 'call_later') as mock_call_later: + with patch.object(registry, 'async_schedule_save') as mock_schedule_save: registry.async_get_or_create('light', 'hue', '1234') - assert len(mock_call_later.mock_calls) == 1 + assert len(mock_schedule_save.mock_calls) == 1 -@asyncio.coroutine -def test_save_timer_reset_on_subsequent_save(hass, registry): - """Test we reset the save timer on a new create.""" - with patch.object(hass.loop, 'call_later') as mock_call_later: - registry.async_get_or_create('light', 'hue', '1234') - - assert len(mock_call_later.mock_calls) == 1 - - with patch.object(hass.loop, 'call_later') as mock_call_later_2: - registry.async_get_or_create('light', 'hue', '5678') - - assert len(mock_call_later().cancel.mock_calls) == 1 - assert len(mock_call_later_2.mock_calls) == 1 - - -@asyncio.coroutine -def test_loading_saving_data(hass, registry): +async def test_loading_saving_data(hass, registry): """Test that we load/save data correctly.""" orig_entry1 = registry.async_get_or_create('light', 'hue', '1234') orig_entry2 = registry.async_get_or_create( @@ -91,18 +75,11 @@ def test_loading_saving_data(hass, registry): assert len(registry.entities) == 2 - with patch(YAML__OPEN_PATH, mock_open(), create=True) as mock_write: - yield from registry._async_save() - - # Mock open calls are: open file, context enter, write, context leave - written = mock_write.mock_calls[2][1][0] - # Now load written data in new registry registry2 = entity_registry.EntityRegistry(hass) + registry2._store = registry._store - with patch('os.path.isfile', return_value=True), \ - patch(YAML__OPEN_PATH, mock_open(read_data=written), create=True): - yield from registry2._async_load() + await registry2.async_load() # Ensure same order assert list(registry.entities) == list(registry2.entities) @@ -139,32 +116,37 @@ def test_is_registered(registry): assert not registry.async_is_registered('light.non_existing') -@asyncio.coroutine -def test_loading_extra_values(hass): +async def test_loading_extra_values(hass, hass_storage): """Test we load extra data from the registry.""" - written = """ -test.named: - platform: super_platform - unique_id: with-name - name: registry override -test.no_name: - platform: super_platform - unique_id: without-name -test.disabled_user: - platform: super_platform - unique_id: disabled-user - disabled_by: user -test.disabled_hass: - platform: super_platform - unique_id: disabled-hass - disabled_by: hass -""" + hass_storage[entity_registry.STORAGE_KEY] = { + 'version': entity_registry.STORAGE_VERSION, + 'data': { + 'entities': [ + { + 'entity_id': 'test.named', + 'platform': 'super_platform', + 'unique_id': 'with-name', + 'name': 'registry override', + }, { + 'entity_id': 'test.no_name', + 'platform': 'super_platform', + 'unique_id': 'without-name', + }, { + 'entity_id': 'test.disabled_user', + 'platform': 'super_platform', + 'unique_id': 'disabled-user', + 'disabled_by': 'user', + }, { + 'entity_id': 'test.disabled_hass', + 'platform': 'super_platform', + 'unique_id': 'disabled-hass', + 'disabled_by': 'hass', + } + ] + } + } - registry = entity_registry.EntityRegistry(hass) - - with patch('os.path.isfile', return_value=True), \ - patch(YAML__OPEN_PATH, mock_open(read_data=written), create=True): - yield from registry._async_load() + registry = await entity_registry.async_get_registry(hass) entry_with_name = registry.async_get_or_create( 'test', 'super_platform', 'with-name') @@ -202,3 +184,31 @@ async def test_updating_config_entry_id(registry): 'light', 'hue', '5678', config_entry_id='mock-id-2') assert entry.entity_id == entry2.entity_id assert entry2.config_entry_id == 'mock-id-2' + + +async def test_migration(hass): + """Test migration from old data to new.""" + old_conf = { + 'light.kitchen': { + 'config_entry_id': 'test-config-id', + 'unique_id': 'test-unique', + 'platform': 'test-platform', + 'name': 'Test Name', + 'disabled_by': 'hass', + } + } + with patch('os.path.isfile', return_value=True), patch('os.remove'), \ + patch('homeassistant.helpers.entity_registry.load_yaml', + return_value=old_conf): + registry = await entity_registry.async_get_registry(hass) + + assert registry.async_is_registered('light.kitchen') + entry = registry.async_get_or_create( + domain='light', + platform='test-platform', + unique_id='test-unique', + config_entry_id='test-config-id', + ) + assert entry.name == 'Test Name' + assert entry.disabled_by == 'hass' + assert entry.config_entry_id == 'test-config-id' diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py index b35b2596802..6cb75899d35 100644 --- a/tests/helpers/test_storage.py +++ b/tests/helpers/test_storage.py @@ -141,11 +141,10 @@ async def test_migrator_no_existing_config(hass, store, hass_storage): async def test_migrator_existing_config(hass, store, hass_storage): """Test migrating existing config.""" with patch('os.path.isfile', return_value=True), \ - patch('os.remove') as mock_remove, \ - patch('homeassistant.util.json.load_json', - return_value={'old': 'config'}): + patch('os.remove') as mock_remove: data = await storage.async_migrator( - hass, 'old-path', store) + hass, 'old-path', store, + old_conf_load_func=lambda _: {'old': 'config'}) assert len(mock_remove.mock_calls) == 1 assert data == {'old': 'config'} @@ -163,12 +162,11 @@ async def test_migrator_transforming_config(hass, store, hass_storage): return {'new': old_config['old']} with patch('os.path.isfile', return_value=True), \ - patch('os.remove') as mock_remove, \ - patch('homeassistant.util.json.load_json', - return_value={'old': 'config'}): + patch('os.remove') as mock_remove: data = await storage.async_migrator( hass, 'old-path', store, - old_conf_migrate_func=old_conf_migrate_func) + old_conf_migrate_func=old_conf_migrate_func, + old_conf_load_func=lambda _: {'old': 'config'}) assert len(mock_remove.mock_calls) == 1 assert data == {'new': 'config'}