Added TrackStates context manager

This commit is contained in:
Paulus Schoutsen 2014-12-14 00:32:20 -08:00
parent dfa1e1c586
commit b091e9c31c
4 changed files with 46 additions and 9 deletions

View File

@ -418,17 +418,13 @@ class State(object):
self.entity_id = entity_id self.entity_id = entity_id
self.state = state self.state = state
self.attributes = attributes or {} self.attributes = attributes or {}
last_changed = last_changed or dt.datetime.now()
# Strip microsecond from last_changed else we cannot guarantee # Strip microsecond from last_changed else we cannot guarantee
# state == State.from_dict(state.as_dict()) # state == State.from_dict(state.as_dict())
# This behavior occurs because to_dict uses datetime_to_str # This behavior occurs because to_dict uses datetime_to_str
# which strips microseconds # which does not preserve microseconds
if last_changed.microsecond: self.last_changed = util.strip_microseconds(
self.last_changed = last_changed - dt.timedelta( last_changed or dt.datetime.now())
microseconds=last_changed.microsecond)
else:
self.last_changed = last_changed
def copy(self): def copy(self):
""" Creates a copy of itself. """ """ Creates a copy of itself. """
@ -504,6 +500,19 @@ class StateMachine(object):
# Make a copy so people won't mutate the state # Make a copy so people won't mutate the state
return state.copy() if state else None return state.copy() if state else None
def get_since(self, point_in_time):
"""
Returns all states that have been changed since point_in_time.
Note: States keep track of last_changed -without- microseconds.
Therefore your point_in_time will also be stripped of microseconds.
"""
point_in_time = util.strip_microseconds(point_in_time)
with self._lock:
return [state for state in self._states.values()
if state.last_changed >= point_in_time]
def is_state(self, entity_id, state): def is_state(self, entity_id, state):
""" Returns True if entity exists and is specified state. """ """ Returns True if entity exists and is specified state. """
return (entity_id in self._states and return (entity_id in self._states and

View File

@ -39,7 +39,6 @@ ATTR_SERVICE = "service"
# Data for a SERVICE_EXECUTED event # Data for a SERVICE_EXECUTED event
ATTR_SERVICE_CALL_ID = "service_call_id" ATTR_SERVICE_CALL_ID = "service_call_id"
ATTR_RESULT = "result"
# Contains one string or a list of strings, each being an entity id # Contains one string or a list of strings, each being an entity id
ATTR_ENTITY_ID = 'entity_id' ATTR_ENTITY_ID = 'entity_id'

View File

@ -1,6 +1,8 @@
""" """
Helper methods for components within Home Assistant. Helper methods for components within Home Assistant.
""" """
from datetime import datetime
from homeassistant import NoEntitySpecifiedError from homeassistant import NoEntitySpecifiedError
from homeassistant.loader import get_component from homeassistant.loader import get_component
@ -33,6 +35,25 @@ def extract_entity_ids(hass, service):
return entity_ids return entity_ids
# pylint: disable=too-few-public-methods, attribute-defined-outside-init
class TrackStates(object):
"""
Records the time when the with-block is entered. Will add all states
that have changed since the start time to the return list when with-block
is exited.
"""
def __init__(self, hass):
self.hass = hass
self.states = []
def __enter__(self):
self.now = datetime.now()
return self.states
def __exit__(self, exc_type, exc_value, traceback):
self.states.extend(self.hass.states.get_since(self.now))
def validate_config(config, items, logger): def validate_config(config, items, logger):
""" """
Validates if all items are available in the configuration. Validates if all items are available in the configuration.

View File

@ -8,7 +8,7 @@ import collections
from itertools import chain from itertools import chain
import threading import threading
import queue import queue
from datetime import datetime from datetime import datetime, timedelta
import re import re
import enum import enum
import socket import socket
@ -57,6 +57,14 @@ def str_to_datetime(dt_str):
return None return None
def strip_microseconds(dattim):
""" Returns a copy of dattime object but with microsecond set to 0. """
if dattim.microsecond:
return dattim - timedelta(microseconds=dattim.microsecond)
else:
return dattim
def split_entity_id(entity_id): def split_entity_id(entity_id):
""" Splits a state entity_id into domain, object_id. """ """ Splits a state entity_id into domain, object_id. """
return entity_id.split(".", 1) return entity_id.split(".", 1)