Properly handle the case when a group includes itself. (#8398)

* Properly handle the case when a group includes itself.

* Fix lint
This commit is contained in:
Andrey 2017-07-08 19:20:11 +03:00 committed by Paulus Schoutsen
parent 57c5ed33ee
commit c5bf4fe339
2 changed files with 41 additions and 25 deletions

View File

@ -184,7 +184,6 @@ def expand_entity_ids(hass, entity_ids):
Async friendly. Async friendly.
""" """
found_ids = [] found_ids = []
for entity_id in entity_ids: for entity_id in entity_ids:
if not isinstance(entity_id, str): if not isinstance(entity_id, str):
continue continue
@ -196,9 +195,13 @@ def expand_entity_ids(hass, entity_ids):
domain, _ = ha.split_entity_id(entity_id) domain, _ = ha.split_entity_id(entity_id)
if domain == DOMAIN: if domain == DOMAIN:
child_entities = get_entity_ids(hass, entity_id)
if entity_id in child_entities:
child_entities = list(child_entities)
child_entities.remove(entity_id)
found_ids.extend( found_ids.extend(
ent_id for ent_id ent_id for ent_id
in expand_entity_ids(hass, get_entity_ids(hass, entity_id)) in expand_entity_ids(hass, child_entities)
if ent_id not in found_ids) if ent_id not in found_ids)
else: else:
@ -223,7 +226,6 @@ def get_entity_ids(hass, entity_id, domain_filter=None):
return [] return []
entity_ids = group.attributes[ATTR_ENTITY_ID] entity_ids = group.attributes[ATTR_ENTITY_ID]
if not domain_filter: if not domain_filter:
return entity_ids return entity_ids

View File

@ -150,6 +150,20 @@ class TestComponentsGroup(unittest.TestCase):
sorted(group.expand_entity_ids( sorted(group.expand_entity_ids(
self.hass, ['light.bowl', test_group.entity_id]))) self.hass, ['light.bowl', test_group.entity_id])))
def test_expand_entity_ids_recursive(self):
"""Test expand_entity_ids method with a group that contains itself."""
self.hass.states.set('light.Bowl', STATE_ON)
self.hass.states.set('light.Ceiling', STATE_OFF)
test_group = group.Group.create_group(
self.hass,
'init_group',
['light.Bowl', 'light.Ceiling', 'group.init_group'],
False)
self.assertEqual(sorted(['light.ceiling', 'light.bowl']),
sorted(group.expand_entity_ids(
self.hass, [test_group.entity_id])))
def test_expand_entity_ids_ignores_non_strings(self): def test_expand_entity_ids_ignores_non_strings(self):
"""Test that non string elements in lists are ignored.""" """Test that non string elements in lists are ignored."""
self.assertEqual([], group.expand_entity_ids(self.hass, [5, True])) self.assertEqual([], group.expand_entity_ids(self.hass, [5, True]))
@ -226,11 +240,11 @@ class TestComponentsGroup(unittest.TestCase):
group_conf = OrderedDict() group_conf = OrderedDict()
group_conf['second_group'] = { group_conf['second_group'] = {
'entities': 'light.Bowl, ' + test_group.entity_id, 'entities': 'light.Bowl, ' + test_group.entity_id,
'icon': 'mdi:work', 'icon': 'mdi:work',
'view': True, 'view': True,
'control': 'hidden', 'control': 'hidden',
} }
group_conf['test_group'] = 'hello.world,sensor.happy' group_conf['test_group'] = 'hello.world,sensor.happy'
group_conf['empty_group'] = {'name': 'Empty Group', 'entities': None} group_conf['empty_group'] = {'name': 'Empty Group', 'entities': None}
@ -275,8 +289,8 @@ class TestComponentsGroup(unittest.TestCase):
self.hass, 'light', ['light.test_1', 'light.test_2']) self.hass, 'light', ['light.test_1', 'light.test_2'])
group.Group.create_group( group.Group.create_group(
self.hass, 'switch', ['switch.test_1', 'switch.test_2']) self.hass, 'switch', ['switch.test_1', 'switch.test_2'])
group.Group.create_group(self.hass, 'group_of_groups', ['group.light', group.Group.create_group(
'group.switch']) self.hass, 'group_of_groups', ['group.light', 'group.switch'])
self.assertEqual( self.assertEqual(
['light.test_1', 'light.test_2', 'switch.test_1', 'switch.test_2'], ['light.test_1', 'light.test_2', 'switch.test_1', 'switch.test_2'],
@ -325,27 +339,26 @@ class TestComponentsGroup(unittest.TestCase):
def test_reloading_groups(self): def test_reloading_groups(self):
"""Test reloading the group config.""" """Test reloading the group config."""
assert setup_component(self.hass, 'group', {'group': { assert setup_component(self.hass, 'group', {'group': {
'second_group': { 'second_group': {
'entities': 'light.Bowl', 'entities': 'light.Bowl',
'icon': 'mdi:work', 'icon': 'mdi:work',
'view': True, 'view': True,
}, },
'test_group': 'hello.world,sensor.happy', 'test_group': 'hello.world,sensor.happy',
'empty_group': {'name': 'Empty Group', 'entities': None}, 'empty_group': {'name': 'Empty Group', 'entities': None},
} }})
})
assert sorted(self.hass.states.entity_ids()) == \ assert sorted(self.hass.states.entity_ids()) == \
['group.empty_group', 'group.second_group', 'group.test_group'] ['group.empty_group', 'group.second_group', 'group.test_group']
assert self.hass.bus.listeners['state_changed'] == 3 assert self.hass.bus.listeners['state_changed'] == 3
with patch('homeassistant.config.load_yaml_config_file', return_value={ with patch('homeassistant.config.load_yaml_config_file', return_value={
'group': { 'group': {
'hello': { 'hello': {
'entities': 'light.Bowl', 'entities': 'light.Bowl',
'icon': 'mdi:work', 'icon': 'mdi:work',
'view': True, 'view': True,
}}}): }}}):
group.reload(self.hass) group.reload(self.hass)
self.hass.block_till_done() self.hass.block_till_done()
@ -395,6 +408,7 @@ def test_service_group_services(hass):
assert hass.services.has_service('group', group.SERVICE_REMOVE) assert hass.services.has_service('group', group.SERVICE_REMOVE)
# pylint: disable=invalid-name
@asyncio.coroutine @asyncio.coroutine
def test_service_group_set_group_remove_group(hass): def test_service_group_set_group_remove_group(hass):
"""Check if service are available.""" """Check if service are available."""