From aeb21596a0acb032cf494893eb236bad44cdc006 Mon Sep 17 00:00:00 2001 From: mvn23 Date: Wed, 3 Oct 2018 23:12:21 +0200 Subject: [PATCH] Fix counter restore. (#17101) Add config option to disable restore (always use initial value on restart). Add unit tests for restore config option. --- homeassistant/components/counter/__init__.py | 20 ++++++----- tests/components/counter/test_init.py | 35 ++++++++++++++++++-- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index 9ef4d4374ce..d67c93c0d6e 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -20,6 +20,7 @@ ATTR_INITIAL = 'initial' ATTR_STEP = 'step' CONF_INITIAL = 'initial' +CONF_RESTORE = 'restore' CONF_STEP = 'step' DEFAULT_INITIAL = 0 @@ -43,6 +44,7 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.positive_int, vol.Optional(CONF_NAME): cv.string, + vol.Optional(CONF_RESTORE, default=True): cv.boolean, vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int, }, None) }) @@ -61,10 +63,11 @@ async def async_setup(hass, config): name = cfg.get(CONF_NAME) initial = cfg.get(CONF_INITIAL) + restore = cfg.get(CONF_RESTORE) step = cfg.get(CONF_STEP) icon = cfg.get(CONF_ICON) - entities.append(Counter(object_id, name, initial, step, icon)) + entities.append(Counter(object_id, name, initial, restore, step, icon)) if not entities: return False @@ -86,10 +89,11 @@ async def async_setup(hass, config): class Counter(Entity): """Representation of a counter.""" - def __init__(self, object_id, name, initial, step, icon): + def __init__(self, object_id, name, initial, restore, step, icon): """Initialize a counter.""" self.entity_id = ENTITY_ID_FORMAT.format(object_id) self._name = name + self._restore = restore self._step = step self._state = self._initial = initial self._icon = icon @@ -124,12 +128,12 @@ class Counter(Entity): async def async_added_to_hass(self): """Call when entity about to be added to Home Assistant.""" - # If not None, we got an initial value. - if self._state is not None: - return - - state = await async_get_last_state(self.hass, self.entity_id) - self._state = state and state.state == state + # __init__ will set self._state to self._initial, only override + # if needed. + if self._restore: + state = await async_get_last_state(self.hass, self.entity_id) + if state is not None: + self._state = int(state.state) async def async_decrement(self): """Decrement the counter.""" diff --git a/tests/components/counter/test_init.py b/tests/components/counter/test_init.py index e5e0ee594ac..929d96d4650 100644 --- a/tests/components/counter/test_init.py +++ b/tests/components/counter/test_init.py @@ -7,7 +7,7 @@ import logging from homeassistant.core import CoreState, State, Context from homeassistant.setup import setup_component, async_setup_component from homeassistant.components.counter import ( - DOMAIN, CONF_INITIAL, CONF_STEP, CONF_NAME, CONF_ICON) + DOMAIN, CONF_INITIAL, CONF_RESTORE, CONF_STEP, CONF_NAME, CONF_ICON) from homeassistant.const import (ATTR_ICON, ATTR_FRIENDLY_NAME) from tests.common import (get_test_home_assistant, mock_restore_cache) @@ -55,6 +55,7 @@ class TestCounter(unittest.TestCase): CONF_NAME: 'Hello World', CONF_ICON: 'mdi:work', CONF_INITIAL: 10, + CONF_RESTORE: False, CONF_STEP: 5, } } @@ -172,9 +173,12 @@ def test_initial_state_overrules_restore_state(hass): yield from async_setup_component(hass, DOMAIN, { DOMAIN: { - 'test1': {}, + 'test1': { + CONF_RESTORE: False, + }, 'test2': { CONF_INITIAL: 10, + CONF_RESTORE: False, }, }}) @@ -187,6 +191,33 @@ def test_initial_state_overrules_restore_state(hass): assert int(state.state) == 10 +@asyncio.coroutine +def test_restore_state_overrules_initial_state(hass): + """Ensure states are restored on startup.""" + mock_restore_cache(hass, ( + State('counter.test1', '11'), + State('counter.test2', '-22'), + )) + + hass.state = CoreState.starting + + yield from async_setup_component(hass, DOMAIN, { + DOMAIN: { + 'test1': {}, + 'test2': { + CONF_INITIAL: 10, + }, + }}) + + state = hass.states.get('counter.test1') + assert state + assert int(state.state) == 11 + + state = hass.states.get('counter.test2') + assert state + assert int(state.state) == -22 + + @asyncio.coroutine def test_no_initial_state_and_no_restore_state(hass): """Ensure that entity is create without initial and restore feature."""