Make State class immutable

This commit is contained in:
Paulus Schoutsen 2016-02-09 23:27:01 -08:00
parent 70a528c04b
commit b0948bef5f
6 changed files with 10 additions and 29 deletions

View File

@ -141,7 +141,7 @@ class Configurator(object):
state = self.hass.states.get(entity_id) state = self.hass.states.get(entity_id)
new_data = state.attributes new_data = dict(state.attributes)
new_data[ATTR_ERRORS] = error new_data[ATTR_ERRORS] = error
self.hass.states.set(entity_id, STATE_CONFIGURE, new_data) self.hass.states.set(entity_id, STATE_CONFIGURE, new_data)

View File

@ -232,7 +232,7 @@ class Recorder(threading.Thread):
else: else:
state_domain = state.domain state_domain = state.domain
state_state = state.state state_state = state.state
state_attr = json.dumps(state.attributes) state_attr = json.dumps(dict(state.attributes))
last_changed = state.last_changed last_changed = state.last_changed
last_updated = state.last_updated last_updated = state.last_updated

View File

@ -10,6 +10,7 @@ import time
import logging import logging
import signal import signal
import threading import threading
from types import MappingProxyType
import enum import enum
import functools as ft import functools as ft
from collections import namedtuple from collections import namedtuple
@ -353,7 +354,7 @@ class State(object):
self.entity_id = entity_id.lower() self.entity_id = entity_id.lower()
self.state = state self.state = state
self.attributes = attributes or {} self.attributes = MappingProxyType(attributes or {})
self.last_updated = dt_util.strip_microseconds( self.last_updated = dt_util.strip_microseconds(
last_updated or dt_util.utcnow()) last_updated or dt_util.utcnow())
@ -381,12 +382,6 @@ class State(object):
self.attributes.get(ATTR_FRIENDLY_NAME) or self.attributes.get(ATTR_FRIENDLY_NAME) or
self.object_id.replace('_', ' ')) self.object_id.replace('_', ' '))
def copy(self):
"""Return a copy of the state."""
return State(self.entity_id, self.state,
dict(self.attributes), self.last_changed,
self.last_updated)
def as_dict(self): def as_dict(self):
"""Return a dict representation of the State. """Return a dict representation of the State.
@ -395,7 +390,7 @@ class State(object):
""" """
return {'entity_id': self.entity_id, return {'entity_id': self.entity_id,
'state': self.state, 'state': self.state,
'attributes': self.attributes, 'attributes': dict(self.attributes),
'last_changed': dt_util.datetime_to_str(self.last_changed), 'last_changed': dt_util.datetime_to_str(self.last_changed),
'last_updated': dt_util.datetime_to_str(self.last_updated)} 'last_updated': dt_util.datetime_to_str(self.last_updated)}
@ -459,14 +454,11 @@ class StateMachine(object):
def all(self): def all(self):
"""Create a list of all states.""" """Create a list of all states."""
with self._lock: with self._lock:
return [state.copy() for state in self._states.values()] return list(self._states.values())
def get(self, entity_id): def get(self, entity_id):
"""Retrieve state of entity_id or None if not found.""" """Retrieve state of entity_id or None if not found."""
state = self._states.get(entity_id.lower()) return self._states.get(entity_id.lower())
# Make a copy so people won't mutate the state
return state.copy() if state else None
def is_state(self, entity_id, state): def is_state(self, entity_id, state):
"""Test if entity exists and is specified state.""" """Test if entity exists and is specified state."""

View File

@ -85,7 +85,7 @@ def reproduce_state(hass, states, blocking=False):
# We group service calls for entities by service call # We group service calls for entities by service call
# json used to create a hashable version of dict with maybe lists in it # json used to create a hashable version of dict with maybe lists in it
key = (service_domain, service, key = (service_domain, service,
json.dumps(state.attributes, sort_keys=True)) json.dumps(dict(state.attributes), sort_keys=True))
to_call[key].append(state.entity_id) to_call[key].append(state.entity_id)
for (service_domain, service, service_data), entity_ids in to_call.items(): for (service_domain, service, service_data), entity_ids in to_call.items():

View File

@ -15,6 +15,7 @@ import socket
import random import random
import string import string
from functools import wraps from functools import wraps
from types import MappingProxyType
from .dt import datetime_to_local_str, utcnow from .dt import datetime_to_local_str, utcnow
@ -42,7 +43,7 @@ def slugify(text):
def repr_helper(inp): def repr_helper(inp):
""" Helps creating a more readable string representation of objects. """ """ Helps creating a more readable string representation of objects. """
if isinstance(inp, dict): if isinstance(inp, (dict, MappingProxyType)):
return ", ".join( return ", ".join(
repr_helper(key)+"="+repr_helper(item) for key, item repr_helper(key)+"="+repr_helper(item) for key, item
in inp.items()) in inp.items())

View File

@ -267,18 +267,6 @@ class TestState(unittest.TestCase):
{ATTR_FRIENDLY_NAME: name}) {ATTR_FRIENDLY_NAME: name})
self.assertEqual(name, state.name) self.assertEqual(name, state.name)
def test_copy(self):
state = ha.State('domain.hello', 'world', {'some': 'attr'})
# Patch dt_util.utcnow() so we know last_updated got copied too
with patch('homeassistant.core.dt_util.utcnow',
return_value=dt_util.utcnow() + timedelta(seconds=10)):
copy = state.copy()
self.assertEqual(state.entity_id, copy.entity_id)
self.assertEqual(state.state, copy.state)
self.assertEqual(state.attributes, copy.attributes)
self.assertEqual(state.last_changed, copy.last_changed)
self.assertEqual(state.last_updated, copy.last_updated)
def test_dict_conversion(self): def test_dict_conversion(self):
state = ha.State('domain.hello', 'world', {'some': 'attr'}) state = ha.State('domain.hello', 'world', {'some': 'attr'})
self.assertEqual(state, ha.State.from_dict(state.as_dict())) self.assertEqual(state, ha.State.from_dict(state.as_dict()))