From 89a548252a89eeaf72853282b45ce3f307609677 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 26 Dec 2014 23:26:39 -0800 Subject: [PATCH] StateMachine is now case insensitive for entity ids --- ha_test/test_core.py | 13 +++++++++++++ homeassistant/__init__.py | 23 +++++++++++++++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/ha_test/test_core.py b/ha_test/test_core.py index 73513eee502..c7a35d83842 100644 --- a/ha_test/test_core.py +++ b/ha_test/test_core.py @@ -285,6 +285,19 @@ class TestStateMachine(unittest.TestCase): self.assertEqual(1, len(specific_runs)) self.assertEqual(3, len(wildcard_runs)) + def test_case_insensitivty(self): + runs = [] + + self.states.track_change( + 'light.BoWl', lambda a, b, c: runs.append(1), + ha.MATCH_ALL, ha.MATCH_ALL) + + self.states.set('light.BOWL', 'off') + self.bus._pool.block_till_done() + + self.assertTrue(self.states.is_state('light.bowl', 'off')) + self.assertEqual(1, len(runs)) + class TestServiceCall(unittest.TestCase): """ Test ServiceCall class. """ diff --git a/homeassistant/__init__.py b/homeassistant/__init__.py index 336fa6b433d..e0a2316dfd8 100644 --- a/homeassistant/__init__.py +++ b/homeassistant/__init__.py @@ -15,6 +15,8 @@ import re import datetime as dt import functools as ft +from requests.structures import CaseInsensitiveDict + from homeassistant.const import ( EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED, @@ -482,15 +484,18 @@ class StateMachine(object): """ Helper class that tracks the state of different entities. """ def __init__(self, bus): - self._states = {} + self._states = CaseInsensitiveDict() self._bus = bus self._lock = threading.Lock() def entity_ids(self, domain_filter=None): """ List of entity ids that are being tracked. """ if domain_filter is not None: - return [entity_id for entity_id in self._states.keys() - if util.split_entity_id(entity_id)[0] == domain_filter] + domain_filter = domain_filter.lower() + + return [state.entity_id for key, state + in self._states.lower_items() + if util.split_entity_id(key)[0] == domain_filter] else: return list(self._states.keys()) @@ -524,9 +529,9 @@ class StateMachine(object): self._states[entity_id].state == state) def remove(self, entity_id): - """ Removes a entity from the state machine. + """ Removes an entity from the state machine. - Returns boolean to indicate if a entity was removed. """ + Returns boolean to indicate if an entity was removed. """ with self._lock: return self._states.pop(entity_id, None) is not None @@ -567,14 +572,16 @@ class StateMachine(object): from_state = _process_match_param(from_state) to_state = _process_match_param(to_state) - # Ensure it is a list with entity ids we want to match on + # Ensure it is a lowercase list with entity ids we want to match on if isinstance(entity_ids, str): - entity_ids = [entity_ids] + entity_ids = [entity_ids.lower()] + else: + entity_ids = [entity_id.lower() for entity_id in entity_ids] @ft.wraps(action) def state_listener(event): """ The listener that listens for specific state changes. """ - if event.data['entity_id'] in entity_ids and \ + if event.data['entity_id'].lower() in entity_ids and \ 'old_state' in event.data and \ _matcher(event.data['old_state'].state, from_state) and \ _matcher(event.data['new_state'].state, to_state):