diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index 98329bc417a..5580518a9a3 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -32,10 +32,17 @@ SERVICE_RESET = "reset" SERVICE_CONFIGURE = "configure" +def _none_to_empty_dict(value): + if value is None: + return {} + return value + + CONFIG_SCHEMA = vol.Schema( { DOMAIN: cv.schema_with_slug_keys( - vol.Any( + vol.All( + _none_to_empty_dict, { vol.Optional(CONF_ICON): cv.icon, vol.Optional( @@ -51,7 +58,6 @@ CONFIG_SCHEMA = vol.Schema( vol.Optional(CONF_RESTORE, default=True): cv.boolean, vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int, }, - None, ) ) }, @@ -70,12 +76,12 @@ async def async_setup(hass, config): cfg = {} name = cfg.get(CONF_NAME) - initial = cfg.get(CONF_INITIAL) - restore = cfg.get(CONF_RESTORE) - step = cfg.get(CONF_STEP) + initial = cfg[CONF_INITIAL] + restore = cfg[CONF_RESTORE] + step = cfg[CONF_STEP] icon = cfg.get(CONF_ICON) - minimum = cfg.get(CONF_MINIMUM) - maximum = cfg.get(CONF_MAXIMUM) + minimum = cfg[CONF_MINIMUM] + maximum = cfg[CONF_MAXIMUM] entities.append( Counter(object_id, name, initial, minimum, maximum, restore, step, icon) diff --git a/tests/components/counter/test_init.py b/tests/components/counter/test_init.py index 3e85a080806..f5ff825e7fb 100644 --- a/tests/components/counter/test_init.py +++ b/tests/components/counter/test_init.py @@ -3,11 +3,15 @@ import logging from homeassistant.components.counter import ( + ATTR_INITIAL, + ATTR_STEP, CONF_ICON, CONF_INITIAL, CONF_NAME, CONF_RESTORE, CONF_STEP, + DEFAULT_INITIAL, + DEFAULT_STEP, DOMAIN, ) from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON @@ -48,6 +52,7 @@ async def test_config_options(hass): CONF_RESTORE: False, CONF_STEP: 5, }, + "test_3": None, } } @@ -56,14 +61,16 @@ async def test_config_options(hass): _LOGGER.debug("ENTITIES: %s", hass.states.async_entity_ids()) - assert count_start + 2 == len(hass.states.async_entity_ids()) + assert count_start + 3 == len(hass.states.async_entity_ids()) await hass.async_block_till_done() state_1 = hass.states.get("counter.test_1") state_2 = hass.states.get("counter.test_2") + state_3 = hass.states.get("counter.test_3") assert state_1 is not None assert state_2 is not None + assert state_3 is not None assert 0 == int(state_1.state) assert ATTR_ICON not in state_1.attributes @@ -73,6 +80,9 @@ async def test_config_options(hass): assert "Hello World" == state_2.attributes.get(ATTR_FRIENDLY_NAME) assert "mdi:work" == state_2.attributes.get(ATTR_ICON) + assert DEFAULT_INITIAL == state_3.attributes.get(ATTR_INITIAL) + assert DEFAULT_STEP == state_3.attributes.get(ATTR_STEP) + async def test_methods(hass): """Test increment, decrement, and reset methods."""