mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 07:07:28 +00:00
Add and restore context in recorder (#15859)
This commit is contained in:
parent
da916d7b27
commit
9512bb9587
@ -114,6 +114,27 @@ def _drop_index(engine, table_name, index_name):
|
|||||||
"critical operation.", index_name, table_name)
|
"critical operation.", index_name, table_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_columns(engine, table_name, columns_def):
|
||||||
|
"""Add columns to a table."""
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
columns_def = ['ADD COLUMN {}'.format(col_def) for col_def in columns_def]
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine.execute(text("ALTER TABLE {table} {columns_def}".format(
|
||||||
|
table=table_name,
|
||||||
|
columns_def=', '.join(columns_def))))
|
||||||
|
return
|
||||||
|
except SQLAlchemyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for column_def in columns_def:
|
||||||
|
engine.execute(text("ALTER TABLE {table} {column_def}".format(
|
||||||
|
table=table_name,
|
||||||
|
column_def=column_def)))
|
||||||
|
|
||||||
|
|
||||||
def _apply_update(engine, new_version, old_version):
|
def _apply_update(engine, new_version, old_version):
|
||||||
"""Perform operations to bring schema up to date."""
|
"""Perform operations to bring schema up to date."""
|
||||||
if new_version == 1:
|
if new_version == 1:
|
||||||
@ -146,6 +167,19 @@ def _apply_update(engine, new_version, old_version):
|
|||||||
elif new_version == 5:
|
elif new_version == 5:
|
||||||
# Create supporting index for States.event_id foreign key
|
# Create supporting index for States.event_id foreign key
|
||||||
_create_index(engine, "states", "ix_states_event_id")
|
_create_index(engine, "states", "ix_states_event_id")
|
||||||
|
elif new_version == 6:
|
||||||
|
_add_columns(engine, "events", [
|
||||||
|
'context_id CHARACTER(36)',
|
||||||
|
'context_user_id CHARACTER(36)',
|
||||||
|
])
|
||||||
|
_create_index(engine, "events", "ix_events_context_id")
|
||||||
|
_create_index(engine, "events", "ix_events_context_user_id")
|
||||||
|
_add_columns(engine, "states", [
|
||||||
|
'context_id CHARACTER(36)',
|
||||||
|
'context_user_id CHARACTER(36)',
|
||||||
|
])
|
||||||
|
_create_index(engine, "states", "ix_states_context_id")
|
||||||
|
_create_index(engine, "states", "ix_states_context_user_id")
|
||||||
else:
|
else:
|
||||||
raise ValueError("No schema migration defined for version {}"
|
raise ValueError("No schema migration defined for version {}"
|
||||||
.format(new_version))
|
.format(new_version))
|
||||||
|
@ -9,14 +9,15 @@ from sqlalchemy import (
|
|||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
from homeassistant.core import Event, EventOrigin, State, split_entity_id
|
from homeassistant.core import (
|
||||||
|
Context, Event, EventOrigin, State, split_entity_id)
|
||||||
from homeassistant.remote import JSONEncoder
|
from homeassistant.remote import JSONEncoder
|
||||||
|
|
||||||
# SQLAlchemy Schema
|
# SQLAlchemy Schema
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
SCHEMA_VERSION = 5
|
SCHEMA_VERSION = 6
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -31,6 +32,8 @@ class Events(Base): # type: ignore
|
|||||||
origin = Column(String(32))
|
origin = Column(String(32))
|
||||||
time_fired = Column(DateTime(timezone=True), index=True)
|
time_fired = Column(DateTime(timezone=True), index=True)
|
||||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||||
|
context_id = Column(String(36), index=True)
|
||||||
|
context_user_id = Column(String(36), index=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_event(event):
|
def from_event(event):
|
||||||
@ -38,16 +41,23 @@ class Events(Base): # type: ignore
|
|||||||
return Events(event_type=event.event_type,
|
return Events(event_type=event.event_type,
|
||||||
event_data=json.dumps(event.data, cls=JSONEncoder),
|
event_data=json.dumps(event.data, cls=JSONEncoder),
|
||||||
origin=str(event.origin),
|
origin=str(event.origin),
|
||||||
time_fired=event.time_fired)
|
time_fired=event.time_fired,
|
||||||
|
context_id=event.context.id,
|
||||||
|
context_user_id=event.context.user_id)
|
||||||
|
|
||||||
def to_native(self):
|
def to_native(self):
|
||||||
"""Convert to a natve HA Event."""
|
"""Convert to a natve HA Event."""
|
||||||
|
context = Context(
|
||||||
|
id=self.context_id,
|
||||||
|
user_id=self.context_user_id
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return Event(
|
return Event(
|
||||||
self.event_type,
|
self.event_type,
|
||||||
json.loads(self.event_data),
|
json.loads(self.event_data),
|
||||||
EventOrigin(self.origin),
|
EventOrigin(self.origin),
|
||||||
_process_timestamp(self.time_fired)
|
_process_timestamp(self.time_fired),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# When json.loads fails
|
# When json.loads fails
|
||||||
@ -69,6 +79,8 @@ class States(Base): # type: ignore
|
|||||||
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow,
|
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow,
|
||||||
index=True)
|
index=True)
|
||||||
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
created = Column(DateTime(timezone=True), default=datetime.utcnow)
|
||||||
|
context_id = Column(String(36), index=True)
|
||||||
|
context_user_id = Column(String(36), index=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
# Used for fetching the state of entities at a specific time
|
# Used for fetching the state of entities at a specific time
|
||||||
@ -82,7 +94,11 @@ class States(Base): # type: ignore
|
|||||||
entity_id = event.data['entity_id']
|
entity_id = event.data['entity_id']
|
||||||
state = event.data.get('new_state')
|
state = event.data.get('new_state')
|
||||||
|
|
||||||
dbstate = States(entity_id=entity_id)
|
dbstate = States(
|
||||||
|
entity_id=entity_id,
|
||||||
|
context_id=event.context.id,
|
||||||
|
context_user_id=event.context.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# State got deleted
|
# State got deleted
|
||||||
if state is None:
|
if state is None:
|
||||||
@ -103,12 +119,17 @@ class States(Base): # type: ignore
|
|||||||
|
|
||||||
def to_native(self):
|
def to_native(self):
|
||||||
"""Convert to an HA state object."""
|
"""Convert to an HA state object."""
|
||||||
|
context = Context(
|
||||||
|
id=self.context_id,
|
||||||
|
user_id=self.context_user_id
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return State(
|
return State(
|
||||||
self.entity_id, self.state,
|
self.entity_id, self.state,
|
||||||
json.loads(self.attributes),
|
json.loads(self.attributes),
|
||||||
_process_timestamp(self.last_changed),
|
_process_timestamp(self.last_changed),
|
||||||
_process_timestamp(self.last_updated)
|
_process_timestamp(self.last_updated),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# When json.loads fails
|
# When json.loads fails
|
||||||
|
@ -423,7 +423,8 @@ class Event:
|
|||||||
self.event_type == other.event_type and
|
self.event_type == other.event_type and
|
||||||
self.data == other.data and
|
self.data == other.data and
|
||||||
self.origin == other.origin and
|
self.origin == other.origin and
|
||||||
self.time_fired == other.time_fired)
|
self.time_fired == other.time_fired and
|
||||||
|
self.context == other.context)
|
||||||
|
|
||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
@ -695,7 +696,8 @@ class State:
|
|||||||
return (self.__class__ == other.__class__ and # type: ignore
|
return (self.__class__ == other.__class__ and # type: ignore
|
||||||
self.entity_id == other.entity_id and
|
self.entity_id == other.entity_id and
|
||||||
self.state == other.state and
|
self.state == other.state and
|
||||||
self.attributes == other.attributes)
|
self.attributes == other.attributes and
|
||||||
|
self.context == other.context)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return the representation of the states."""
|
"""Return the representation of the states."""
|
||||||
|
@ -266,7 +266,7 @@ def mock_state_change_event(hass, new_state, old_state=None):
|
|||||||
if old_state:
|
if old_state:
|
||||||
event_data['old_state'] = old_state
|
event_data['old_state'] = old_state
|
||||||
|
|
||||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
|
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -60,7 +60,7 @@ class TestStates(unittest.TestCase):
|
|||||||
'entity_id': 'sensor.temperature',
|
'entity_id': 'sensor.temperature',
|
||||||
'old_state': None,
|
'old_state': None,
|
||||||
'new_state': state,
|
'new_state': state,
|
||||||
})
|
}, context=state.context)
|
||||||
assert state == States.from_event(event).to_native()
|
assert state == States.from_event(event).to_native()
|
||||||
|
|
||||||
def test_from_event_to_delete_state(self):
|
def test_from_event_to_delete_state(self):
|
||||||
|
@ -83,9 +83,10 @@ class TestComponentHistory(unittest.TestCase):
|
|||||||
self.wait_recording_done()
|
self.wait_recording_done()
|
||||||
|
|
||||||
# Get states returns everything before POINT
|
# Get states returns everything before POINT
|
||||||
self.assertEqual(states,
|
for state1, state2 in zip(
|
||||||
sorted(history.get_states(self.hass, future),
|
states, sorted(history.get_states(self.hass, future),
|
||||||
key=lambda state: state.entity_id))
|
key=lambda state: state.entity_id)):
|
||||||
|
assert state1 == state2
|
||||||
|
|
||||||
# Test get_state here because we have a DB setup
|
# Test get_state here because we have a DB setup
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -246,8 +246,9 @@ class TestEvent(unittest.TestCase):
|
|||||||
"""Test events."""
|
"""Test events."""
|
||||||
now = dt_util.utcnow()
|
now = dt_util.utcnow()
|
||||||
data = {'some': 'attr'}
|
data = {'some': 'attr'}
|
||||||
|
context = ha.Context()
|
||||||
event1, event2 = [
|
event1, event2 = [
|
||||||
ha.Event('some_type', data, time_fired=now)
|
ha.Event('some_type', data, time_fired=now, context=context)
|
||||||
for _ in range(2)
|
for _ in range(2)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user