mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Allow chaining contexts (#21028)
* Allow chaining contexts * Add stubbed out migration
This commit is contained in:
parent
b39846fb6b
commit
52f337ef00
@ -7,7 +7,7 @@ import logging
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant.core import CoreState
|
||||
from homeassistant.core import CoreState, Context
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
|
||||
@ -280,15 +280,21 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if skip_condition or self._cond_func(variables):
|
||||
self.async_set_context(context)
|
||||
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
|
||||
ATTR_NAME: self._name,
|
||||
ATTR_ENTITY_ID: self.entity_id,
|
||||
}, context=context)
|
||||
await self._async_action(self.entity_id, variables, context)
|
||||
self._last_triggered = utcnow()
|
||||
await self.async_update_ha_state()
|
||||
if not skip_condition and not self._cond_func(variables):
|
||||
return
|
||||
|
||||
# Create a new context referring to the old context.
|
||||
parent_id = None if context is None else context.id
|
||||
trigger_context = Context(parent_id=parent_id)
|
||||
|
||||
self.async_set_context(trigger_context)
|
||||
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
|
||||
ATTR_NAME: self._name,
|
||||
ATTR_ENTITY_ID: self.entity_id,
|
||||
}, context=trigger_context)
|
||||
await self._async_action(self.entity_id, variables, trigger_context)
|
||||
self._last_triggered = utcnow()
|
||||
await self.async_update_ha_state()
|
||||
|
||||
async def async_will_remove_from_hass(self):
|
||||
"""Remove listeners when removing automation from HASS."""
|
||||
|
@ -220,6 +220,15 @@ def _apply_update(engine, new_version, old_version):
|
||||
_create_index(engine, "states", "ix_states_context_user_id")
|
||||
elif new_version == 7:
|
||||
_create_index(engine, "states", "ix_states_entity_id")
|
||||
elif new_version == 8:
|
||||
# Pending migration, want to group a few.
|
||||
pass
|
||||
# _add_columns(engine, "events", [
|
||||
# 'context_parent_id CHARACTER(36)',
|
||||
# ])
|
||||
# _add_columns(engine, "states", [
|
||||
# 'context_parent_id CHARACTER(36)',
|
||||
# ])
|
||||
else:
|
||||
raise ValueError("No schema migration defined for version {}"
|
||||
.format(new_version))
|
||||
|
@ -34,16 +34,20 @@ class Events(Base): # type: ignore
|
||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
context_id = Column(String(36), index=True)
|
||||
context_user_id = Column(String(36), index=True)
|
||||
# context_parent_id = Column(String(36), index=True)
|
||||
|
||||
@staticmethod
|
||||
def from_event(event):
|
||||
"""Create an event database object from a native event."""
|
||||
return Events(event_type=event.event_type,
|
||||
event_data=json.dumps(event.data, cls=JSONEncoder),
|
||||
origin=str(event.origin),
|
||||
time_fired=event.time_fired,
|
||||
context_id=event.context.id,
|
||||
context_user_id=event.context.user_id)
|
||||
return Events(
|
||||
event_type=event.event_type,
|
||||
event_data=json.dumps(event.data, cls=JSONEncoder),
|
||||
origin=str(event.origin),
|
||||
time_fired=event.time_fired,
|
||||
context_id=event.context.id,
|
||||
context_user_id=event.context.user_id,
|
||||
# context_parent_id=event.context.parent_id,
|
||||
)
|
||||
|
||||
def to_native(self):
|
||||
"""Convert to a natve HA Event."""
|
||||
@ -81,6 +85,7 @@ class States(Base): # type: ignore
|
||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
context_id = Column(String(36), index=True)
|
||||
context_user_id = Column(String(36), index=True)
|
||||
# context_parent_id = Column(String(36), index=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Used for fetching the state of entities at a specific time
|
||||
@ -99,6 +104,7 @@ class States(Base): # type: ignore
|
||||
entity_id=entity_id,
|
||||
context_id=event.context.id,
|
||||
context_user_id=event.context.user_id,
|
||||
# context_parent_id=event.context.parent_id,
|
||||
)
|
||||
|
||||
# State got deleted
|
||||
|
@ -409,6 +409,10 @@ class Context:
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parent_id = attr.ib(
|
||||
type=Optional[str],
|
||||
default=None
|
||||
)
|
||||
id = attr.ib(
|
||||
type=str,
|
||||
default=attr.Factory(lambda: uuid.uuid4().hex),
|
||||
@ -418,6 +422,7 @@ class Context:
|
||||
"""Return a dictionary representation of the context."""
|
||||
return {
|
||||
'id': self.id,
|
||||
'parent_id': self.parent_id,
|
||||
'user_id': self.user_id,
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ async def test_if_fires_on_event(hass, calls):
|
||||
hass.bus.async_fire('test_event', context=context)
|
||||
await hass.async_block_till_done()
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
|
||||
await common.async_turn_off(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -68,7 +68,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
assert 'geo_location - geo_location.entity - hello - hello - test' == \
|
||||
calls[0].data['some']
|
||||
|
||||
@ -221,7 +221,7 @@ async def test_if_fires_on_zone_appear(hass, calls):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
assert 'geo_location - geo_location.entity - - hello - test' == \
|
||||
calls[0].data['some']
|
||||
|
||||
|
@ -369,38 +369,47 @@ async def test_shared_context(hass, calls):
|
||||
})
|
||||
|
||||
context = Context()
|
||||
automation_mock = Mock()
|
||||
first_automation_listener = Mock()
|
||||
event_mock = Mock()
|
||||
|
||||
hass.bus.async_listen('test_event2', automation_mock)
|
||||
hass.bus.async_listen('test_event2', first_automation_listener)
|
||||
hass.bus.async_listen(EVENT_AUTOMATION_TRIGGERED, event_mock)
|
||||
hass.bus.async_fire('test_event', context=context)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Ensure events was fired
|
||||
assert automation_mock.call_count == 1
|
||||
assert first_automation_listener.call_count == 1
|
||||
assert event_mock.call_count == 2
|
||||
|
||||
# Ensure context carries through the event
|
||||
args, kwargs = automation_mock.call_args
|
||||
assert args[0].context == context
|
||||
# Verify automation triggered evenet for 'hello' automation
|
||||
args, kwargs = event_mock.call_args_list[0]
|
||||
first_trigger_context = args[0].context
|
||||
assert first_trigger_context.parent_id == context.id
|
||||
# Ensure event data has all attributes set
|
||||
assert args[0].data.get(ATTR_NAME) is not None
|
||||
assert args[0].data.get(ATTR_ENTITY_ID) is not None
|
||||
|
||||
for call in event_mock.call_args_list:
|
||||
args, kwargs = call
|
||||
assert args[0].context == context
|
||||
# Ensure event data has all attributes set
|
||||
assert args[0].data.get(ATTR_NAME) is not None
|
||||
assert args[0].data.get(ATTR_ENTITY_ID) is not None
|
||||
# Ensure context set correctly for event fired by 'hello' automation
|
||||
args, kwargs = first_automation_listener.call_args
|
||||
assert args[0].context is first_trigger_context
|
||||
|
||||
# Ensure the automation state shares the same context
|
||||
# Ensure the 'hello' automation state has the right context
|
||||
state = hass.states.get('automation.hello')
|
||||
assert state is not None
|
||||
assert state.context == context
|
||||
assert state.context is first_trigger_context
|
||||
|
||||
# Verify automation triggered evenet for 'bye' automation
|
||||
args, kwargs = event_mock.call_args_list[1]
|
||||
second_trigger_context = args[0].context
|
||||
assert second_trigger_context.parent_id == first_trigger_context.id
|
||||
# Ensure event data has all attributes set
|
||||
assert args[0].data.get(ATTR_NAME) is not None
|
||||
assert args[0].data.get(ATTR_ENTITY_ID) is not None
|
||||
|
||||
# Ensure the service call from the second automation
|
||||
# shares the same context
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context == context
|
||||
assert calls[0].context is second_trigger_context
|
||||
|
||||
|
||||
async def test_services(hass, calls):
|
||||
|
@ -45,7 +45,7 @@ async def test_if_fires_on_entity_change_below(hass, calls):
|
||||
hass.states.async_set('test.entity', 9, context=context)
|
||||
await hass.async_block_till_done()
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
|
||||
# Set above 12 so the automation will fire again
|
||||
hass.states.async_set('test.entity', 12)
|
||||
@ -134,7 +134,7 @@ async def test_if_not_fires_on_entity_change_below_to_below(hass, calls):
|
||||
hass.states.async_set('test.entity', 9, context=context)
|
||||
await hass.async_block_till_done()
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
|
||||
# already below so should not fire again
|
||||
hass.states.async_set('test.entity', 5)
|
||||
|
@ -55,7 +55,7 @@ async def test_if_fires_on_entity_change(hass, calls):
|
||||
hass.states.async_set('test.entity', 'world', context=context)
|
||||
await hass.async_block_till_done()
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
assert 'state - test.entity - hello - world - None' == \
|
||||
calls[0].data['some']
|
||||
|
||||
|
@ -257,7 +257,7 @@ async def test_if_fires_on_change_with_template_advanced(hass, calls):
|
||||
hass.states.async_set('test.entity', 'world', context=context)
|
||||
await hass.async_block_till_done()
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
assert 'template - test.entity - hello - world' == \
|
||||
calls[0].data['some']
|
||||
|
||||
|
@ -66,7 +66,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert 1 == len(calls)
|
||||
assert calls[0].context is context
|
||||
assert calls[0].context.parent_id == context.id
|
||||
assert 'zone - test.entity - hello - hello - test' == \
|
||||
calls[0].data['some']
|
||||
|
||||
|
@ -310,6 +310,7 @@ class TestEvent(unittest.TestCase):
|
||||
'time_fired': now,
|
||||
'context': {
|
||||
'id': event.context.id,
|
||||
'parent_id': None,
|
||||
'user_id': event.context.user_id,
|
||||
},
|
||||
}
|
||||
@ -1076,3 +1077,16 @@ async def test_service_call_event_contains_original_data(hass):
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data['number'] == 23
|
||||
assert calls[0].context is context
|
||||
|
||||
|
||||
def test_context():
|
||||
"""Test context init."""
|
||||
c = ha.Context()
|
||||
assert c.user_id is None
|
||||
assert c.parent_id is None
|
||||
assert c.id is not None
|
||||
|
||||
c = ha.Context(23, 100)
|
||||
assert c.user_id == 23
|
||||
assert c.parent_id == 100
|
||||
assert c.id is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user