mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Add storage helper and migrate config entries (#15045)
* Add storage helper * Migrate config entries to use the storage helper * Make sure tests do not do I/O * Lint * Add versions to stored data * Add more instance variables * Make migrator load config if nothing to migrate * Address comments
This commit is contained in:
parent
672a3c7178
commit
ae51dc08bf
@ -225,7 +225,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
||||
hass, config, add_devices, config_path, discovery_info=None)
|
||||
return False
|
||||
else:
|
||||
config_file = save_json(config_path, DEFAULT_CONFIG)
|
||||
save_json(config_path, DEFAULT_CONFIG)
|
||||
request_app_setup(
|
||||
hass, config, add_devices, config_path, discovery_info=None)
|
||||
return False
|
||||
|
@ -112,15 +112,13 @@ the flow from the config panel.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from . import data_entry_flow
|
||||
from .core import callback
|
||||
from .exceptions import HomeAssistantError
|
||||
from .setup import async_setup_component, async_process_deps_reqs
|
||||
from .util.json import load_json, save_json
|
||||
from .util.decorator import Registry
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component, async_process_deps_reqs
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -136,6 +134,10 @@ FLOWS = [
|
||||
]
|
||||
|
||||
|
||||
STORAGE_KEY = 'core.config_entries'
|
||||
STORAGE_VERSION = 1
|
||||
|
||||
# Deprecated since 0.73
|
||||
PATH_CONFIG = '.config_entries.json'
|
||||
|
||||
SAVE_DELAY = 1
|
||||
@ -271,7 +273,7 @@ class ConfigEntries:
|
||||
hass, self._async_create_flow, self._async_finish_flow)
|
||||
self._hass_config = hass_config
|
||||
self._entries = None
|
||||
self._sched_save = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
def async_domains(self):
|
||||
@ -305,7 +307,7 @@ class ConfigEntries:
|
||||
raise UnknownEntry
|
||||
|
||||
entry = self._entries.pop(found)
|
||||
self._async_schedule_save()
|
||||
await self._async_schedule_save()
|
||||
|
||||
unloaded = await entry.async_unload(self.hass)
|
||||
|
||||
@ -314,14 +316,14 @@ class ConfigEntries:
|
||||
}
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the config."""
|
||||
path = self.hass.config.path(PATH_CONFIG)
|
||||
if not os.path.isfile(path):
|
||||
self._entries = []
|
||||
return
|
||||
"""Handle loading the config."""
|
||||
# Migrating for config entries stored before 0.73
|
||||
config = await self.hass.helpers.storage.async_migrator(
|
||||
self.hass.config.path(PATH_CONFIG), self._store,
|
||||
old_conf_migrate_func=_old_conf_migrator
|
||||
)
|
||||
|
||||
entries = await self.hass.async_add_job(load_json, path)
|
||||
self._entries = [ConfigEntry(**entry) for entry in entries]
|
||||
self._entries = [ConfigEntry(**entry) for entry in config['entries']]
|
||||
|
||||
async def async_forward_entry_setup(self, entry, component):
|
||||
"""Forward the setup of an entry to a different component.
|
||||
@ -372,7 +374,7 @@ class ConfigEntries:
|
||||
source=result['source'],
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._async_schedule_save()
|
||||
await self._async_schedule_save()
|
||||
|
||||
# Setup entry
|
||||
if entry.domain in self.hass.config.components:
|
||||
@ -416,20 +418,14 @@ class ConfigEntries:
|
||||
|
||||
return handler()
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self):
|
||||
"""Schedule saving the entity registry."""
|
||||
if self._sched_save is not None:
|
||||
self._sched_save.cancel()
|
||||
|
||||
self._sched_save = self.hass.loop.call_later(
|
||||
SAVE_DELAY, self.hass.async_add_job, self._async_save
|
||||
)
|
||||
|
||||
async def _async_save(self):
|
||||
async def _async_schedule_save(self):
|
||||
"""Save the entity registry to a file."""
|
||||
self._sched_save = None
|
||||
data = [entry.as_dict() for entry in self._entries]
|
||||
data = {
|
||||
'entries': [entry.as_dict() for entry in self._entries]
|
||||
}
|
||||
await self._store.async_save(data, delay=SAVE_DELAY)
|
||||
|
||||
await self.hass.async_add_job(
|
||||
save_json, self.hass.config.path(PATH_CONFIG), data)
|
||||
|
||||
async def _old_conf_migrator(old_config):
|
||||
"""Migrate the pre-0.73 config format to the latest version."""
|
||||
return {'entries': old_config}
|
||||
|
@ -230,6 +230,20 @@ class HomeAssistant(object):
|
||||
|
||||
return task
|
||||
|
||||
@callback
|
||||
def async_add_executor_job(
|
||||
self,
|
||||
target: Callable[..., Any],
|
||||
*args: Any) -> asyncio.tasks.Task:
|
||||
"""Add an executor job from within the event loop."""
|
||||
task = self.loop.run_in_executor(None, target, *args)
|
||||
|
||||
# If a task is scheduled
|
||||
if self._track_task:
|
||||
self._pending_tasks.append(task)
|
||||
|
||||
return task
|
||||
|
||||
@callback
|
||||
def async_track_tasks(self):
|
||||
"""Track tasks so you can wait for all tasks to be done."""
|
||||
|
157
homeassistant/helpers/storage.py
Normal file
157
homeassistant/helpers/storage.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""Helper to help store data."""
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import json
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
|
||||
STORAGE_DIR = '.storage'
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None):
|
||||
"""Helper function to migrate old data to a store and then load data.
|
||||
|
||||
async def old_conf_migrate_func(old_data)
|
||||
"""
|
||||
def load_old_config():
|
||||
"""Helper to load old config."""
|
||||
if not os.path.isfile(old_path):
|
||||
return None
|
||||
|
||||
return json.load_json(old_path)
|
||||
|
||||
config = await hass.async_add_executor_job(load_old_config)
|
||||
|
||||
if config is None:
|
||||
return await store.async_load()
|
||||
|
||||
if old_conf_migrate_func is not None:
|
||||
config = await old_conf_migrate_func(config)
|
||||
|
||||
await store.async_save(config)
|
||||
await hass.async_add_executor_job(os.remove, old_path)
|
||||
return config
|
||||
|
||||
|
||||
@bind_hass
|
||||
class Store:
|
||||
"""Class to help storing data."""
|
||||
|
||||
def __init__(self, hass, version: int, key: str):
|
||||
"""Initialize storage class."""
|
||||
self.version = version
|
||||
self.key = key
|
||||
self.hass = hass
|
||||
self._data = None
|
||||
self._unsub_delay_listener = None
|
||||
self._unsub_stop_listener = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
"""Return the config path."""
|
||||
return self.hass.config.path(STORAGE_DIR, self.key)
|
||||
|
||||
async def async_load(self):
|
||||
"""Load data.
|
||||
|
||||
If the expected version does not match the given version, the migrate
|
||||
function will be invoked with await migrate_func(version, config).
|
||||
"""
|
||||
if self._data is not None:
|
||||
data = self._data
|
||||
else:
|
||||
data = await self.hass.async_add_executor_job(
|
||||
json.load_json, self.path, None)
|
||||
|
||||
if data is None:
|
||||
return {}
|
||||
|
||||
if data['version'] == self.version:
|
||||
return data['data']
|
||||
|
||||
return await self._async_migrate_func(data['version'], data['data'])
|
||||
|
||||
async def async_save(self, data: Dict, *, delay: Optional[int] = None):
|
||||
"""Save data with an optional delay."""
|
||||
self._data = {
|
||||
'version': self.version,
|
||||
'key': self.key,
|
||||
'data': data,
|
||||
}
|
||||
|
||||
self._async_cleanup_delay_listener()
|
||||
|
||||
if delay is None:
|
||||
self._async_cleanup_stop_listener()
|
||||
await self._async_handle_write_data()
|
||||
return
|
||||
|
||||
self._unsub_delay_listener = async_call_later(
|
||||
self.hass, delay, self._async_callback_delayed_write)
|
||||
|
||||
self._async_ensure_stop_listener()
|
||||
|
||||
@callback
|
||||
def _async_ensure_stop_listener(self):
|
||||
"""Ensure that we write if we quit before delay has passed."""
|
||||
if self._unsub_stop_listener is None:
|
||||
self._unsub_stop_listener = self.hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, self._async_callback_stop_write)
|
||||
|
||||
@callback
|
||||
def _async_cleanup_stop_listener(self):
|
||||
"""Clean up a stop listener."""
|
||||
if self._unsub_stop_listener is not None:
|
||||
self._unsub_stop_listener()
|
||||
self._unsub_stop_listener = None
|
||||
|
||||
@callback
|
||||
def _async_cleanup_delay_listener(self):
|
||||
"""Clean up a delay listener."""
|
||||
if self._unsub_delay_listener is not None:
|
||||
self._unsub_delay_listener()
|
||||
self._unsub_delay_listener = None
|
||||
|
||||
async def _async_callback_delayed_write(self, _now):
|
||||
"""Handle a delayed write callback."""
|
||||
self._unsub_delay_listener = None
|
||||
self._async_cleanup_stop_listener()
|
||||
await self._async_handle_write_data()
|
||||
|
||||
async def _async_callback_stop_write(self, _event):
|
||||
"""Handle a write because Home Assistant is stopping."""
|
||||
self._unsub_stop_listener = None
|
||||
self._async_cleanup_delay_listener()
|
||||
await self._async_handle_write_data()
|
||||
|
||||
async def _async_handle_write_data(self, *_args):
|
||||
"""Handler to handle writing the config."""
|
||||
data = self._data
|
||||
self._data = None
|
||||
|
||||
async with self._write_lock:
|
||||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
self._write_data, self.path, data)
|
||||
except (json.SerializationError, json.WriteError) as err:
|
||||
_LOGGER.error('Error writing config for %s: %s', self.key, err)
|
||||
|
||||
def _write_data(self, path: str, data: Dict):
|
||||
"""Write the data."""
|
||||
if not os.path.isdir(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
|
||||
_LOGGER.debug('Writing data for %s', self.key)
|
||||
json.save_json(path, data)
|
||||
|
||||
async def _async_migrate_func(self, old_version, old_data):
|
||||
"""Migrate to the new version."""
|
||||
raise NotImplementedError
|
@ -11,6 +11,14 @@ _LOGGER = logging.getLogger(__name__)
|
||||
_UNDEFINED = object()
|
||||
|
||||
|
||||
class SerializationError(HomeAssistantError):
|
||||
"""Error serializing the data to JSON."""
|
||||
|
||||
|
||||
class WriteError(HomeAssistantError):
|
||||
"""Error writing the data."""
|
||||
|
||||
|
||||
def load_json(filename: str, default: Union[List, Dict] = _UNDEFINED) \
|
||||
-> Union[List, Dict]:
|
||||
"""Load JSON data from a file and return as dict or list.
|
||||
@ -41,13 +49,11 @@ def save_json(filename: str, data: Union[List, Dict]):
|
||||
data = json.dumps(data, sort_keys=True, indent=4)
|
||||
with open(filename, 'w', encoding='utf-8') as fdesc:
|
||||
fdesc.write(data)
|
||||
return True
|
||||
except TypeError as error:
|
||||
_LOGGER.exception('Failed to serialize to JSON: %s',
|
||||
filename)
|
||||
raise HomeAssistantError(error)
|
||||
raise SerializationError(error)
|
||||
except OSError as error:
|
||||
_LOGGER.exception('Saving JSON file failed: %s',
|
||||
filename)
|
||||
raise HomeAssistantError(error)
|
||||
return False
|
||||
raise WriteError(error)
|
||||
|
@ -14,7 +14,7 @@ from homeassistant import auth, core as ha, data_entry_flow, config_entries
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.config import async_process_component_config
|
||||
from homeassistant.helpers import (
|
||||
intent, entity, restore_state, entity_registry,
|
||||
intent, entity, restore_state, entity_registry,
|
||||
entity_platform)
|
||||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||
import homeassistant.util.dt as date_util
|
||||
@ -110,8 +110,6 @@ def get_test_home_assistant():
|
||||
def async_test_home_assistant(loop):
|
||||
"""Return a Home Assistant object pointing at test config dir."""
|
||||
hass = ha.HomeAssistant(loop)
|
||||
hass.config_entries = config_entries.ConfigEntries(hass, {})
|
||||
hass.config_entries._entries = []
|
||||
hass.config.async_load = Mock()
|
||||
store = auth.AuthStore(hass)
|
||||
hass.auth = auth.AuthManager(hass, store, {})
|
||||
@ -137,6 +135,10 @@ def async_test_home_assistant(loop):
|
||||
hass.config.units = METRIC_SYSTEM
|
||||
hass.config.skip_pip = True
|
||||
|
||||
hass.config_entries = config_entries.ConfigEntries(hass, {})
|
||||
hass.config_entries._entries = []
|
||||
hass.config_entries._store._async_ensure_stop_listener = lambda: None
|
||||
|
||||
hass.state = ha.CoreState.running
|
||||
|
||||
# Mock async_start
|
||||
|
158
tests/helpers/test_storage.py
Normal file
158
tests/helpers/test_storage.py
Normal file
@ -0,0 +1,158 @@
|
||||
"""Tests for the storage helper."""
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.helpers import storage
|
||||
from homeassistant.util import dt
|
||||
|
||||
from tests.common import async_fire_time_changed, mock_coro
|
||||
|
||||
|
||||
MOCK_VERSION = 1
|
||||
MOCK_KEY = 'storage-test'
|
||||
MOCK_DATA = {'hello': 'world'}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_save():
|
||||
"""Fixture to mock JSON save."""
|
||||
written = []
|
||||
with patch('homeassistant.util.json.save_json',
|
||||
side_effect=lambda *args: written.append(args)):
|
||||
yield written
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_load(mock_save):
|
||||
"""Fixture to mock JSON read."""
|
||||
with patch('homeassistant.util.json.load_json',
|
||||
side_effect=lambda *args: mock_save[-1][1]):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
"""Fixture of a store that prevents writing on HASS stop."""
|
||||
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
|
||||
store._async_ensure_stop_listener = lambda: None
|
||||
yield store
|
||||
|
||||
|
||||
async def test_loading(hass, store, mock_save, mock_load):
|
||||
"""Test we can save and load data."""
|
||||
await store.async_save(MOCK_DATA)
|
||||
data = await store.async_load()
|
||||
assert data == MOCK_DATA
|
||||
|
||||
|
||||
async def test_loading_non_existing(hass, store):
|
||||
"""Test we can save and load data."""
|
||||
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
|
||||
data = await store.async_load()
|
||||
assert data == {}
|
||||
|
||||
|
||||
async def test_saving_with_delay(hass, store, mock_save):
|
||||
"""Test saving data after a delay."""
|
||||
await store.async_save(MOCK_DATA, delay=1)
|
||||
assert len(mock_save) == 0
|
||||
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_save) == 1
|
||||
|
||||
|
||||
async def test_saving_on_stop(hass, mock_save):
|
||||
"""Test delayed saves trigger when we quit Home Assistant."""
|
||||
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
|
||||
await store.async_save(MOCK_DATA, delay=1)
|
||||
assert len(mock_save) == 0
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_save) == 1
|
||||
|
||||
|
||||
async def test_loading_while_delay(hass, store, mock_save, mock_load):
|
||||
"""Test we load new data even if not written yet."""
|
||||
await store.async_save({'delay': 'no'})
|
||||
assert len(mock_save) == 1
|
||||
|
||||
await store.async_save({'delay': 'yes'}, delay=1)
|
||||
assert len(mock_save) == 1
|
||||
|
||||
data = await store.async_load()
|
||||
assert data == {'delay': 'yes'}
|
||||
|
||||
|
||||
async def test_writing_while_writing_delay(hass, store, mock_save, mock_load):
|
||||
"""Test a write while a write with delay is active."""
|
||||
await store.async_save({'delay': 'yes'}, delay=1)
|
||||
assert len(mock_save) == 0
|
||||
await store.async_save({'delay': 'no'})
|
||||
assert len(mock_save) == 1
|
||||
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
|
||||
await hass.async_block_till_done()
|
||||
assert len(mock_save) == 1
|
||||
|
||||
data = await store.async_load()
|
||||
assert data == {'delay': 'no'}
|
||||
|
||||
|
||||
async def test_migrator_no_existing_config(hass, store, mock_save):
|
||||
"""Test migrator with no existing config."""
|
||||
with patch('os.path.isfile', return_value=False), \
|
||||
patch.object(store, 'async_load',
|
||||
return_value=mock_coro({'cur': 'config'})):
|
||||
data = await storage.async_migrator(
|
||||
hass, 'old-path', store)
|
||||
|
||||
assert data == {'cur': 'config'}
|
||||
assert len(mock_save) == 0
|
||||
|
||||
|
||||
async def test_migrator_existing_config(hass, store, mock_save):
|
||||
"""Test migrating existing config."""
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch('os.remove') as mock_remove, \
|
||||
patch('homeassistant.util.json.load_json',
|
||||
return_value={'old': 'config'}):
|
||||
data = await storage.async_migrator(
|
||||
hass, 'old-path', store)
|
||||
|
||||
assert len(mock_remove.mock_calls) == 1
|
||||
assert data == {'old': 'config'}
|
||||
assert len(mock_save) == 1
|
||||
assert mock_save[0][1] == {
|
||||
'key': MOCK_KEY,
|
||||
'version': MOCK_VERSION,
|
||||
'data': data,
|
||||
}
|
||||
|
||||
|
||||
async def test_migrator_transforming_config(hass, store, mock_save):
|
||||
"""Test migrating config to new format."""
|
||||
async def old_conf_migrate_func(old_config):
|
||||
"""Migrate old config to new format."""
|
||||
return {'new': old_config['old']}
|
||||
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch('os.remove') as mock_remove, \
|
||||
patch('homeassistant.util.json.load_json',
|
||||
return_value={'old': 'config'}):
|
||||
data = await storage.async_migrator(
|
||||
hass, 'old-path', store,
|
||||
old_conf_migrate_func=old_conf_migrate_func)
|
||||
|
||||
assert len(mock_remove.mock_calls) == 1
|
||||
assert data == {'new': 'config'}
|
||||
assert len(mock_save) == 1
|
||||
assert mock_save[0][1] == {
|
||||
'key': MOCK_KEY,
|
||||
'version': MOCK_VERSION,
|
||||
'data': data,
|
||||
}
|
@ -1,13 +1,16 @@
|
||||
"""Test the config manager."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, loader, data_entry_flow
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt
|
||||
|
||||
from tests.common import MockModule, mock_coro, MockConfigEntry
|
||||
from tests.common import (
|
||||
MockModule, mock_coro, MockConfigEntry, async_fire_time_changed)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -15,6 +18,7 @@ def manager(hass):
|
||||
"""Fixture of a loaded config manager."""
|
||||
manager = config_entries.ConfigEntries(hass, {})
|
||||
manager._entries = []
|
||||
manager._store._async_ensure_stop_listener = lambda: None
|
||||
hass.config_entries = manager
|
||||
return manager
|
||||
|
||||
@ -151,7 +155,9 @@ def test_domains_gets_uniques(manager):
|
||||
@asyncio.coroutine
|
||||
def test_saving_and_loading(hass):
|
||||
"""Test that we're saving and loading correctly."""
|
||||
loader.set_component(hass, 'test', MockModule('test'))
|
||||
loader.set_component(
|
||||
hass, 'test',
|
||||
MockModule('test', async_setup_entry=lambda *args: mock_coro(True)))
|
||||
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
@ -183,13 +189,12 @@ def test_saving_and_loading(hass):
|
||||
json_path = 'homeassistant.util.json.open'
|
||||
|
||||
with patch('homeassistant.config_entries.HANDLERS.get',
|
||||
return_value=Test2Flow), \
|
||||
patch.object(config_entries, 'SAVE_DELAY', 0):
|
||||
return_value=Test2Flow):
|
||||
yield from hass.config_entries.flow.async_init('test')
|
||||
|
||||
with patch(json_path, mock_open(), create=True) as mock_write:
|
||||
# To trigger the call_later
|
||||
yield from asyncio.sleep(0, loop=hass.loop)
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
|
||||
# To execute the save
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
@ -199,7 +204,7 @@ def test_saving_and_loading(hass):
|
||||
# Now load written data in new config manager
|
||||
manager = config_entries.ConfigEntries(hass, {})
|
||||
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
with patch('os.path.isfile', return_value=False), \
|
||||
patch(json_path, mock_open(read_data=written), create=True):
|
||||
yield from manager.async_load()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user