Avoid recording state_changed events in the events table (#71165)

* squash

fix mypy

* Update homeassistant/components/recorder/models.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* pass all columns

* fix commented out code

* reduce logbook query complexity

* merge

* comment

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2022-05-02 02:10:34 -05:00 committed by GitHub
parent 7026e5dd11
commit 5db014666c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 151 additions and 134 deletions

View File

@ -111,13 +111,28 @@ ALL_EVENT_TYPES = [
*ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED, *ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED,
] ]
EVENT_COLUMNS = [ EVENT_COLUMNS = [
Events.event_type, Events.event_type.label("event_type"),
Events.event_data, Events.event_data.label("event_data"),
Events.time_fired, Events.time_fired.label("time_fired"),
Events.context_id, Events.context_id.label("context_id"),
Events.context_user_id, Events.context_user_id.label("context_user_id"),
Events.context_parent_id, Events.context_parent_id.label("context_parent_id"),
]
STATE_COLUMNS = [
States.state.label("state"),
States.entity_id.label("entity_id"),
States.attributes.label("attributes"),
StateAttributes.shared_attrs.label("shared_attrs"),
]
EMPTY_STATE_COLUMNS = [
literal(value=None, type_=sqlalchemy.String).label("state"),
literal(value=None, type_=sqlalchemy.String).label("entity_id"),
literal(value=None, type_=sqlalchemy.Text).label("attributes"),
literal(value=None, type_=sqlalchemy.Text).label("shared_attrs"),
] ]
SCRIPT_AUTOMATION_EVENTS = {EVENT_AUTOMATION_TRIGGERED, EVENT_SCRIPT_STARTED} SCRIPT_AUTOMATION_EVENTS = {EVENT_AUTOMATION_TRIGGERED, EVENT_SCRIPT_STARTED}
@ -502,43 +517,47 @@ def _get_events(
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
old_state = aliased(States, name="old_state") old_state = aliased(States, name="old_state")
query: Query
query = _generate_events_query_without_states(session)
query = _apply_event_time_filter(query, start_day, end_day)
query = _apply_event_types_filter(
hass, query, ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED
)
if entity_ids is not None: if entity_ids is not None:
query = _generate_events_query_without_states(session)
query = _apply_event_time_filter(query, start_day, end_day)
query = _apply_event_types_filter(
hass, query, ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED
)
if entity_matches_only: if entity_matches_only:
# When entity_matches_only is provided, contexts and events that do not # When entity_matches_only is provided, contexts and events that do not
# contain the entity_ids are not included in the logbook response. # contain the entity_ids are not included in the logbook response.
query = _apply_event_entity_id_matchers(query, entity_ids) query = _apply_event_entity_id_matchers(query, entity_ids)
query = query.outerjoin(EventData, (Events.data_id == EventData.data_id)) query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
query = query.union_all( query = query.union_all(
_generate_states_query( _generate_states_query(
session, start_day, end_day, old_state, entity_ids session, start_day, end_day, old_state, entity_ids
) )
) )
else: else:
query = _generate_events_query(session)
query = _apply_event_time_filter(query, start_day, end_day)
query = _apply_events_types_and_states_filter(
hass, query, old_state
).filter(
(States.last_updated == States.last_changed)
| (Events.event_type != EVENT_STATE_CHANGED)
)
if filters:
query = query.filter(
filters.entity_filter() | (Events.event_type != EVENT_STATE_CHANGED) # type: ignore[no-untyped-call]
)
if context_id is not None: if context_id is not None:
query = query.filter(Events.context_id == context_id) query = query.filter(Events.context_id == context_id)
query = query.outerjoin(EventData, (Events.data_id == EventData.data_id)) query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
states_query = _generate_states_query(
session, start_day, end_day, old_state, entity_ids
)
if context_id is not None:
# Once all the old `state_changed` events
# are gone from the database this query can
# be simplified to filter only on States.context_id == context_id
states_query = states_query.outerjoin(
Events, (States.event_id == Events.event_id)
)
states_query = states_query.filter(
(States.context_id == context_id)
| (States.context_id.is_(None) & (Events.context_id == context_id))
)
if filters:
states_query = states_query.filter(filters.entity_filter()) # type: ignore[no-untyped-call]
query = query.union_all(states_query)
query = query.order_by(Events.time_fired) query = query.order_by(Events.time_fired)
return list( return list(
@ -546,36 +565,22 @@ def _get_events(
) )
def _generate_events_query(session: Session) -> Query:
return session.query(
*EVENT_COLUMNS,
EventData.shared_data,
States.state,
States.entity_id,
States.attributes,
StateAttributes.shared_attrs,
)
def _generate_events_query_without_data(session: Session) -> Query: def _generate_events_query_without_data(session: Session) -> Query:
return session.query( return session.query(
*EVENT_COLUMNS, literal(value=EVENT_STATE_CHANGED, type_=sqlalchemy.String).label("event_type"),
literal(value=None, type_=sqlalchemy.Text).label("event_data"),
States.last_changed.label("time_fired"),
States.context_id.label("context_id"),
States.context_user_id.label("context_user_id"),
States.context_parent_id.label("context_parent_id"),
literal(value=None, type_=sqlalchemy.Text).label("shared_data"), literal(value=None, type_=sqlalchemy.Text).label("shared_data"),
States.state, *STATE_COLUMNS,
States.entity_id,
States.attributes,
StateAttributes.shared_attrs,
) )
def _generate_events_query_without_states(session: Session) -> Query: def _generate_events_query_without_states(session: Session) -> Query:
return session.query( return session.query(
*EVENT_COLUMNS, *EVENT_COLUMNS, EventData.shared_data.label("shared_data"), *EMPTY_STATE_COLUMNS
EventData.shared_data,
literal(value=None, type_=sqlalchemy.String).label("state"),
literal(value=None, type_=sqlalchemy.String).label("entity_id"),
literal(value=None, type_=sqlalchemy.Text).label("attributes"),
literal(value=None, type_=sqlalchemy.Text).label("shared_attrs"),
) )
@ -584,41 +589,19 @@ def _generate_states_query(
start_day: dt, start_day: dt,
end_day: dt, end_day: dt,
old_state: States, old_state: States,
entity_ids: Iterable[str], entity_ids: Iterable[str] | None,
) -> Query: ) -> Query:
return ( query = (
_generate_events_query_without_data(session) _generate_events_query_without_data(session)
.outerjoin(Events, (States.event_id == Events.event_id))
.outerjoin(old_state, (States.old_state_id == old_state.state_id)) .outerjoin(old_state, (States.old_state_id == old_state.state_id))
.filter(_missing_state_matcher(old_state)) .filter(_missing_state_matcher(old_state))
.filter(_not_continuous_entity_matcher()) .filter(_not_continuous_entity_matcher())
.filter((States.last_updated > start_day) & (States.last_updated < end_day)) .filter((States.last_updated > start_day) & (States.last_updated < end_day))
.filter( .filter(States.last_updated == States.last_changed)
(States.last_updated == States.last_changed)
& States.entity_id.in_(entity_ids)
)
.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
)
) )
if entity_ids:
query = query.filter(States.entity_id.in_(entity_ids))
def _apply_events_types_and_states_filter( return query.outerjoin(
hass: HomeAssistant, query: Query, old_state: States
) -> Query:
events_query = (
query.outerjoin(States, (Events.event_id == States.event_id))
.outerjoin(old_state, (States.old_state_id == old_state.state_id))
.filter(
(Events.event_type != EVENT_STATE_CHANGED)
| _missing_state_matcher(old_state)
)
.filter(
(Events.event_type != EVENT_STATE_CHANGED)
| _not_continuous_entity_matcher()
)
)
return _apply_event_types_filter(hass, events_query, ALL_EVENT_TYPES).outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id) StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
) )

View File

@ -1223,8 +1223,8 @@ class Recorder(threading.Thread):
] = dbevent_data ] = dbevent_data
self.event_session.add(dbevent_data) self.event_session.add(dbevent_data)
self.event_session.add(dbevent)
if event.event_type != EVENT_STATE_CHANGED: if event.event_type != EVENT_STATE_CHANGED:
self.event_session.add(dbevent)
return return
try: try:
@ -1272,7 +1272,6 @@ class Recorder(threading.Thread):
self._pending_expunge.append(dbstate) self._pending_expunge.append(dbstate)
else: else:
dbstate.state = None dbstate.state = None
dbstate.event = dbevent
self.event_session.add(dbstate) self.event_session.add(dbstate)
def _handle_database_error(self, err: Exception) -> bool: def _handle_database_error(self, err: Exception) -> bool:

View File

@ -442,7 +442,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# and we would have to move to something like # and we would have to move to something like
# sqlalchemy alembic to make that work # sqlalchemy alembic to make that work
# #
_drop_index(instance, "states", "ix_states_context_id") # no longer dropping ix_states_context_id since its recreated in 28
_drop_index(instance, "states", "ix_states_context_user_id") _drop_index(instance, "states", "ix_states_context_user_id")
# This index won't be there if they were not running # This index won't be there if they were not running
# nightly but we don't treat that as a critical issue # nightly but we don't treat that as a critical issue
@ -652,6 +652,24 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
elif new_version == 27: elif new_version == 27:
_add_columns(instance, "events", [f"data_id {big_int}"]) _add_columns(instance, "events", [f"data_id {big_int}"])
_create_index(instance, "events", "ix_events_data_id") _create_index(instance, "events", "ix_events_data_id")
elif new_version == 28:
_add_columns(instance, "events", ["origin_idx INTEGER"])
# We never use the user_id or parent_id index
_drop_index(instance, "events", "ix_events_context_user_id")
_drop_index(instance, "events", "ix_events_context_parent_id")
_add_columns(
instance,
"states",
[
"origin_idx INTEGER",
"context_id VARCHAR(36)",
"context_user_id VARCHAR(36)",
"context_parent_id VARCHAR(36)",
],
)
_create_index(instance, "states", "ix_states_context_id")
# Once there are no longer any state_changed events
# in the events table we can drop the index on states.event_id
else: else:
raise ValueError(f"No schema migration defined for version {new_version}") raise ValueError(f"No schema migration defined for version {new_version}")

View File

@ -17,6 +17,7 @@ from sqlalchemy import (
Identity, Identity,
Index, Index,
Integer, Integer,
SmallInteger,
String, String,
Text, Text,
distinct, distinct,
@ -43,7 +44,7 @@ from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
# pylint: disable=invalid-name # pylint: disable=invalid-name
Base = declarative_base() Base = declarative_base()
SCHEMA_VERSION = 27 SCHEMA_VERSION = 28
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -86,6 +87,8 @@ DOUBLE_TYPE = (
.with_variant(oracle.DOUBLE_PRECISION(), "oracle") .with_variant(oracle.DOUBLE_PRECISION(), "oracle")
.with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
) )
EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote]
EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)}
class Events(Base): # type: ignore[misc,valid-type] class Events(Base): # type: ignore[misc,valid-type]
@ -98,14 +101,15 @@ class Events(Base): # type: ignore[misc,valid-type]
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
) )
__tablename__ = TABLE_EVENTS __tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True) event_id = Column(Integer, Identity(), primary_key=True) # no longer used
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE)) event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) # no longer used
origin_idx = Column(SmallInteger)
time_fired = Column(DATETIME_TYPE, index=True) time_fired = Column(DATETIME_TYPE, index=True)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True) data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True)
event_data_rel = relationship("EventData") event_data_rel = relationship("EventData")
@ -114,7 +118,7 @@ class Events(Base): # type: ignore[misc,valid-type]
return ( return (
f"<recorder.Events(" f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', " f"id={self.event_id}, type='{self.event_type}', "
f"origin='{self.origin}', time_fired='{self.time_fired}'" f"origin_idx='{self.origin_idx}', time_fired='{self.time_fired}'"
f", data_id={self.data_id})>" f", data_id={self.data_id})>"
) )
@ -124,7 +128,7 @@ class Events(Base): # type: ignore[misc,valid-type]
return Events( return Events(
event_type=event.event_type, event_type=event.event_type,
event_data=None, event_data=None,
origin=str(event.origin.value), origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin),
time_fired=event.time_fired, time_fired=event.time_fired,
context_id=event.context.id, context_id=event.context.id,
context_user_id=event.context.user_id, context_user_id=event.context.user_id,
@ -142,7 +146,9 @@ class Events(Base): # type: ignore[misc,valid-type]
return Event( return Event(
self.event_type, self.event_type,
json.loads(self.event_data) if self.event_data else {}, json.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin), EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
process_timestamp(self.time_fired), process_timestamp(self.time_fired),
context=context, context=context,
) )
@ -222,7 +228,10 @@ class States(Base): # type: ignore[misc,valid-type]
attributes_id = Column( attributes_id = Column(
Integer, ForeignKey("state_attributes.attributes_id"), index=True Integer, ForeignKey("state_attributes.attributes_id"), index=True
) )
event = relationship("Events", uselist=False) context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID))
origin_idx = Column(SmallInteger) # 0 is local, 1 is remote
old_state = relationship("States", remote_side=[state_id]) old_state = relationship("States", remote_side=[state_id])
state_attributes = relationship("StateAttributes") state_attributes = relationship("StateAttributes")
@ -242,7 +251,14 @@ class States(Base): # type: ignore[misc,valid-type]
"""Create object from a state_changed event.""" """Create object from a state_changed event."""
entity_id = event.data["entity_id"] entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state") state: State | None = event.data.get("new_state")
dbstate = States(entity_id=entity_id, attributes=None) dbstate = States(
entity_id=entity_id,
attributes=None,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin),
)
# None state means the state was removed from the state machine # None state means the state was removed from the state machine
if state is None: if state is None:
@ -258,6 +274,11 @@ class States(Base): # type: ignore[misc,valid-type]
def to_native(self, validate_entity_id: bool = True) -> State | None: def to_native(self, validate_entity_id: bool = True) -> State | None:
"""Convert to an HA state object.""" """Convert to an HA state object."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try: try:
return State( return State(
self.entity_id, self.entity_id,
@ -267,9 +288,7 @@ class States(Base): # type: ignore[misc,valid-type]
json.loads(self.attributes) if self.attributes else {}, json.loads(self.attributes) if self.attributes else {},
process_timestamp(self.last_changed), process_timestamp(self.last_changed),
process_timestamp(self.last_updated), process_timestamp(self.last_updated),
# Join the events table on event_id to get the context instead context=context,
# as it will always be there for state_changed events
context=Context(id=None), # type: ignore[arg-type]
validate_entity_id=validate_entity_id, validate_entity_id=validate_entity_id,
) )
except ValueError: except ValueError:

View File

@ -83,7 +83,7 @@ def purge_old_data(
if short_term_statistics: if short_term_statistics:
_purge_short_term_statistics(session, short_term_statistics) _purge_short_term_statistics(session, short_term_statistics)
if event_ids or statistics_runs or short_term_statistics: if state_ids or event_ids or statistics_runs or short_term_statistics:
# Return false, as we might not be done yet. # Return false, as we might not be done yet.
_LOGGER.debug("Purging hasn't fully completed yet") _LOGGER.debug("Purging hasn't fully completed yet")
return False return False
@ -103,27 +103,31 @@ def _select_event_state_attributes_ids_data_ids_to_purge(
) -> tuple[set[int], set[int], set[int], set[int]]: ) -> tuple[set[int], set[int], set[int], set[int]]:
"""Return a list of event, state, and attribute ids to purge.""" """Return a list of event, state, and attribute ids to purge."""
events = ( events = (
session.query( session.query(Events.event_id, Events.data_id)
Events.event_id, Events.data_id, States.state_id, States.attributes_id
)
.outerjoin(States, Events.event_id == States.event_id)
.filter(Events.time_fired < purge_before) .filter(Events.time_fired < purge_before)
.limit(MAX_ROWS_TO_PURGE) .limit(MAX_ROWS_TO_PURGE)
.all() .all()
) )
_LOGGER.debug("Selected %s event ids to remove", len(events)) _LOGGER.debug("Selected %s event ids to remove", len(events))
states = (
session.query(States.state_id, States.attributes_id)
.filter(States.last_updated < purge_before)
.limit(MAX_ROWS_TO_PURGE)
.all()
)
_LOGGER.debug("Selected %s state ids to remove", len(states))
event_ids = set() event_ids = set()
state_ids = set() state_ids = set()
attributes_ids = set() attributes_ids = set()
data_ids = set() data_ids = set()
for event in events: for event in events:
event_ids.add(event.event_id) event_ids.add(event.event_id)
if event.state_id:
state_ids.add(event.state_id)
if event.attributes_id:
attributes_ids.add(event.attributes_id)
if event.data_id: if event.data_id:
data_ids.add(event.data_id) data_ids.add(event.data_id)
for state in states:
state_ids.add(state.state_id)
if state.attributes_id:
attributes_ids.add(state.attributes_id)
return event_ids, state_ids, attributes_ids, data_ids return event_ids, state_ids, attributes_ids, data_ids

View File

@ -49,7 +49,7 @@ from homeassistant.const import (
STATE_LOCKED, STATE_LOCKED,
STATE_UNLOCKED, STATE_UNLOCKED,
) )
from homeassistant.core import Context, CoreState, Event, HomeAssistant, callback from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.setup import async_setup_component, setup_component from homeassistant.setup import async_setup_component, setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -162,7 +162,7 @@ async def test_state_gets_saved_when_set_before_start_event(
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
db_states = list(session.query(States)) db_states = list(session.query(States))
assert len(db_states) == 1 assert len(db_states) == 1
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
async def test_saving_state(hass: HomeAssistant, recorder_mock): async def test_saving_state(hass: HomeAssistant, recorder_mock):
@ -182,9 +182,9 @@ async def test_saving_state(hass: HomeAssistant, recorder_mock):
state = db_state.to_native() state = db_state.to_native()
state.attributes = db_state_attributes.to_native() state.attributes = db_state_attributes.to_native()
assert len(db_states) == 1 assert len(db_states) == 1
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
assert state == _state_empty_context(hass, entity_id) assert state == _state_with_context(hass, entity_id)
async def test_saving_many_states( async def test_saving_many_states(
@ -210,7 +210,7 @@ async def test_saving_many_states(
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
db_states = list(session.query(States)) db_states = list(session.query(States))
assert len(db_states) == 6 assert len(db_states) == 6
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
async def test_saving_state_with_intermixed_time_changes( async def test_saving_state_with_intermixed_time_changes(
@ -234,7 +234,7 @@ async def test_saving_state_with_intermixed_time_changes(
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
db_states = list(session.query(States)) db_states = list(session.query(States))
assert len(db_states) == 2 assert len(db_states) == 2
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
def test_saving_state_with_exception(hass, hass_recorder, caplog): def test_saving_state_with_exception(hass, hass_recorder, caplog):
@ -411,7 +411,7 @@ def test_saving_state_with_commit_interval_zero(hass_recorder):
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
db_states = list(session.query(States)) db_states = list(session.query(States))
assert len(db_states) == 1 assert len(db_states) == 1
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
def _add_entities(hass, entity_ids): def _add_entities(hass, entity_ids):
@ -454,12 +454,10 @@ def _add_events(hass, events):
return events return events
def _state_empty_context(hass, entity_id): def _state_with_context(hass, entity_id):
# We don't restore context unless we need it by joining the # We don't restore context unless we need it by joining the
# events table on the event_id for state_changed events # events table on the event_id for state_changed events
state = hass.states.get(entity_id) return hass.states.get(entity_id)
state.context = Context(id=None)
return state
# pylint: disable=redefined-outer-name,invalid-name # pylint: disable=redefined-outer-name,invalid-name
@ -468,7 +466,7 @@ def test_saving_state_include_domains(hass_recorder):
hass = hass_recorder({"include": {"domains": "test2"}}) hass = hass_recorder({"include": {"domains": "test2"}})
states = _add_entities(hass, ["test.recorder", "test2.recorder"]) states = _add_entities(hass, ["test.recorder", "test2.recorder"])
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_with_context(hass, "test2.recorder") == states[0]
def test_saving_state_include_domains_globs(hass_recorder): def test_saving_state_include_domains_globs(hass_recorder):
@ -480,8 +478,8 @@ def test_saving_state_include_domains_globs(hass_recorder):
hass, ["test.recorder", "test2.recorder", "test3.included_entity"] hass, ["test.recorder", "test2.recorder", "test3.included_entity"]
) )
assert len(states) == 2 assert len(states) == 2
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_with_context(hass, "test2.recorder") == states[0]
assert _state_empty_context(hass, "test3.included_entity") == states[1] assert _state_with_context(hass, "test3.included_entity") == states[1]
def test_saving_state_incl_entities(hass_recorder): def test_saving_state_incl_entities(hass_recorder):
@ -489,7 +487,7 @@ def test_saving_state_incl_entities(hass_recorder):
hass = hass_recorder({"include": {"entities": "test2.recorder"}}) hass = hass_recorder({"include": {"entities": "test2.recorder"}})
states = _add_entities(hass, ["test.recorder", "test2.recorder"]) states = _add_entities(hass, ["test.recorder", "test2.recorder"])
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_with_context(hass, "test2.recorder") == states[0]
def test_saving_event_exclude_event_type(hass_recorder): def test_saving_event_exclude_event_type(hass_recorder):
@ -518,7 +516,7 @@ def test_saving_state_exclude_domains(hass_recorder):
hass = hass_recorder({"exclude": {"domains": "test"}}) hass = hass_recorder({"exclude": {"domains": "test"}})
states = _add_entities(hass, ["test.recorder", "test2.recorder"]) states = _add_entities(hass, ["test.recorder", "test2.recorder"])
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_with_context(hass, "test2.recorder") == states[0]
def test_saving_state_exclude_domains_globs(hass_recorder): def test_saving_state_exclude_domains_globs(hass_recorder):
@ -530,7 +528,7 @@ def test_saving_state_exclude_domains_globs(hass_recorder):
hass, ["test.recorder", "test2.recorder", "test2.excluded_entity"] hass, ["test.recorder", "test2.recorder", "test2.excluded_entity"]
) )
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_with_context(hass, "test2.recorder") == states[0]
def test_saving_state_exclude_entities(hass_recorder): def test_saving_state_exclude_entities(hass_recorder):
@ -538,7 +536,7 @@ def test_saving_state_exclude_entities(hass_recorder):
hass = hass_recorder({"exclude": {"entities": "test.recorder"}}) hass = hass_recorder({"exclude": {"entities": "test.recorder"}})
states = _add_entities(hass, ["test.recorder", "test2.recorder"]) states = _add_entities(hass, ["test.recorder", "test2.recorder"])
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test2.recorder") == states[0] assert _state_with_context(hass, "test2.recorder") == states[0]
def test_saving_state_exclude_domain_include_entity(hass_recorder): def test_saving_state_exclude_domain_include_entity(hass_recorder):
@ -571,8 +569,8 @@ def test_saving_state_include_domain_exclude_entity(hass_recorder):
) )
states = _add_entities(hass, ["test.recorder", "test2.recorder", "test.ok"]) states = _add_entities(hass, ["test.recorder", "test2.recorder", "test.ok"])
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test.ok") == states[0] assert _state_with_context(hass, "test.ok") == states[0]
assert _state_empty_context(hass, "test.ok").state == "state2" assert _state_with_context(hass, "test.ok").state == "state2"
def test_saving_state_include_domain_glob_exclude_entity(hass_recorder): def test_saving_state_include_domain_glob_exclude_entity(hass_recorder):
@ -587,8 +585,8 @@ def test_saving_state_include_domain_glob_exclude_entity(hass_recorder):
hass, ["test.recorder", "test2.recorder", "test.ok", "test2.included_entity"] hass, ["test.recorder", "test2.recorder", "test.ok", "test2.included_entity"]
) )
assert len(states) == 1 assert len(states) == 1
assert _state_empty_context(hass, "test.ok") == states[0] assert _state_with_context(hass, "test.ok") == states[0]
assert _state_empty_context(hass, "test.ok").state == "state2" assert _state_with_context(hass, "test.ok").state == "state2"
def test_saving_state_and_removing_entity(hass, hass_recorder): def test_saving_state_and_removing_entity(hass, hass_recorder):
@ -1153,8 +1151,8 @@ def test_service_disable_states_not_recording(hass, hass_recorder):
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
db_states = list(session.query(States)) db_states = list(session.query(States))
assert len(db_states) == 1 assert len(db_states) == 1
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
assert db_states[0].to_native() == _state_empty_context(hass, "test.two") assert db_states[0].to_native() == _state_with_context(hass, "test.two")
def test_service_disable_run_information_recorded(tmpdir): def test_service_disable_run_information_recorded(tmpdir):
@ -1257,7 +1255,7 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
db_states = list(session.query(States)) db_states = list(session.query(States))
assert len(db_states) == 1 assert len(db_states) == 1
assert db_states[0].event_id > 0 assert db_states[0].event_id is None
return db_states[0].to_native() return db_states[0].to_native()
state = await hass.async_add_executor_job(_get_last_state) state = await hass.async_add_executor_job(_get_last_state)

View File

@ -39,9 +39,6 @@ def test_from_event_to_db_state():
{"entity_id": "sensor.temperature", "old_state": None, "new_state": state}, {"entity_id": "sensor.temperature", "old_state": None, "new_state": state},
context=state.context, context=state.context,
) )
# We don't restore context unless we need it by joining the
# events table on the event_id for state_changed events
state.context = ha.Context(id=None)
assert state == States.from_event(event).to_native() assert state == States.from_event(event).to_native()

View File

@ -64,7 +64,7 @@ async def test_purge_old_states(
assert state_attributes.count() == 3 assert state_attributes.count() == 3
events = session.query(Events).filter(Events.event_type == "state_changed") events = session.query(Events).filter(Events.event_type == "state_changed")
assert events.count() == 6 assert events.count() == 0
assert "test.recorder2" in instance._old_states assert "test.recorder2" in instance._old_states
purge_before = dt_util.utcnow() - timedelta(days=4) purge_before = dt_util.utcnow() - timedelta(days=4)
@ -108,7 +108,7 @@ async def test_purge_old_states(
assert states[5].old_state_id == states[4].state_id assert states[5].old_state_id == states[4].state_id
events = session.query(Events).filter(Events.event_type == "state_changed") events = session.query(Events).filter(Events.event_type == "state_changed")
assert events.count() == 6 assert events.count() == 0
assert "test.recorder2" in instance._old_states assert "test.recorder2" in instance._old_states
state_attributes = session.query(StateAttributes) state_attributes = session.query(StateAttributes)
@ -793,7 +793,6 @@ async def test_purge_filtered_states(
assert session.query(StateAttributes).count() == 11 assert session.query(StateAttributes).count() == 11
# Finally make sure we can delete them all except for the ones missing an event_id
service_data = {"keep_days": 0} service_data = {"keep_days": 0}
await hass.services.async_call( await hass.services.async_call(
recorder.DOMAIN, recorder.SERVICE_PURGE, service_data recorder.DOMAIN, recorder.SERVICE_PURGE, service_data
@ -805,8 +804,8 @@ async def test_purge_filtered_states(
remaining = list(session.query(States)) remaining = list(session.query(States))
for state in remaining: for state in remaining:
assert state.event_id is None assert state.event_id is None
assert len(remaining) == 3 assert len(remaining) == 0
assert session.query(StateAttributes).count() == 1 assert session.query(StateAttributes).count() == 0
@pytest.mark.parametrize("use_sqlite", (True, False), indirect=True) @pytest.mark.parametrize("use_sqlite", (True, False), indirect=True)