Make State class immutable

This commit is contained in:
Paulus Schoutsen 2016-02-09 23:27:01 -08:00
parent 70a528c04b
commit b0948bef5f
6 changed files with 10 additions and 29 deletions

View File

@ -141,7 +141,7 @@ class Configurator(object):
state = self.hass.states.get(entity_id)
new_data = state.attributes
new_data = dict(state.attributes)
new_data[ATTR_ERRORS] = error
self.hass.states.set(entity_id, STATE_CONFIGURE, new_data)

View File

@ -232,7 +232,7 @@ class Recorder(threading.Thread):
else:
state_domain = state.domain
state_state = state.state
state_attr = json.dumps(state.attributes)
state_attr = json.dumps(dict(state.attributes))
last_changed = state.last_changed
last_updated = state.last_updated

View File

@ -10,6 +10,7 @@ import time
import logging
import signal
import threading
from types import MappingProxyType
import enum
import functools as ft
from collections import namedtuple
@ -353,7 +354,7 @@ class State(object):
self.entity_id = entity_id.lower()
self.state = state
self.attributes = attributes or {}
self.attributes = MappingProxyType(attributes or {})
self.last_updated = dt_util.strip_microseconds(
last_updated or dt_util.utcnow())
@ -381,12 +382,6 @@ class State(object):
self.attributes.get(ATTR_FRIENDLY_NAME) or
self.object_id.replace('_', ' '))
def copy(self):
"""Return a copy of the state."""
return State(self.entity_id, self.state,
dict(self.attributes), self.last_changed,
self.last_updated)
def as_dict(self):
"""Return a dict representation of the State.
@ -395,7 +390,7 @@ class State(object):
"""
return {'entity_id': self.entity_id,
'state': self.state,
'attributes': self.attributes,
'attributes': dict(self.attributes),
'last_changed': dt_util.datetime_to_str(self.last_changed),
'last_updated': dt_util.datetime_to_str(self.last_updated)}
@ -459,14 +454,11 @@ class StateMachine(object):
def all(self):
"""Create a list of all states."""
with self._lock:
return [state.copy() for state in self._states.values()]
return list(self._states.values())
def get(self, entity_id):
"""Retrieve state of entity_id or None if not found."""
state = self._states.get(entity_id.lower())
# Make a copy so people won't mutate the state
return state.copy() if state else None
return self._states.get(entity_id.lower())
def is_state(self, entity_id, state):
"""Test if entity exists and is specified state."""

View File

@ -85,7 +85,7 @@ def reproduce_state(hass, states, blocking=False):
# We group service calls for entities by service call
# json used to create a hashable version of dict with maybe lists in it
key = (service_domain, service,
json.dumps(state.attributes, sort_keys=True))
json.dumps(dict(state.attributes), sort_keys=True))
to_call[key].append(state.entity_id)
for (service_domain, service, service_data), entity_ids in to_call.items():

View File

@ -15,6 +15,7 @@ import socket
import random
import string
from functools import wraps
from types import MappingProxyType
from .dt import datetime_to_local_str, utcnow
@ -42,7 +43,7 @@ def slugify(text):
def repr_helper(inp):
""" Helps creating a more readable string representation of objects. """
if isinstance(inp, dict):
if isinstance(inp, (dict, MappingProxyType)):
return ", ".join(
repr_helper(key)+"="+repr_helper(item) for key, item
in inp.items())

View File

@ -267,18 +267,6 @@ class TestState(unittest.TestCase):
{ATTR_FRIENDLY_NAME: name})
self.assertEqual(name, state.name)
def test_copy(self):
state = ha.State('domain.hello', 'world', {'some': 'attr'})
# Patch dt_util.utcnow() so we know last_updated got copied too
with patch('homeassistant.core.dt_util.utcnow',
return_value=dt_util.utcnow() + timedelta(seconds=10)):
copy = state.copy()
self.assertEqual(state.entity_id, copy.entity_id)
self.assertEqual(state.state, copy.state)
self.assertEqual(state.attributes, copy.attributes)
self.assertEqual(state.last_changed, copy.last_changed)
self.assertEqual(state.last_updated, copy.last_updated)
def test_dict_conversion(self):
state = ha.State('domain.hello', 'world', {'some': 'attr'})
self.assertEqual(state, ha.State.from_dict(state.as_dict()))