Improve recorder migration tests (#59075)

Erik Montnemery 2021-11-05 04:21:38 +01:00 committed by GitHub
parent dc1edc98fc
commit 185f7beafc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1585 additions and 5 deletions

View File

@@ -9,7 +9,7 @@ from homeassistant.core import HomeAssistant
 from homeassistant.util import dt as dt_util

 from tests.common import async_fire_time_changed, fire_time_changed
-from tests.components.recorder import models_original
+from tests.components.recorder import models_schema_0

 DEFAULT_PURGE_TASKS = 3
@@ -91,5 +91,5 @@ def create_engine_test(*args, **kwargs):
     This simulates an existing db with the old schema.
     """
     engine = create_engine(*args, **kwargs)
-    models_original.Base.metadata.create_all(engine)
+    models_schema_0.Base.metadata.create_all(engine)
     return engine
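
For context, a sketch of how this helper is typically wired into a test; the patch target is taken from the test changes later in this commit, while the body itself is illustrative and not part of the diff:

# Point the recorder at an engine pre-populated with the old schema:
# create_engine_test builds the schema-0 tables before the recorder starts,
# so setting up the component exercises the migration path.
from unittest.mock import patch

with patch(
    "homeassistant.components.recorder.create_engine", new=create_engine_test
):
    ...  # set up the recorder; migration now runs against the old schema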

View File

@@ -0,0 +1,457 @@
"""Models for SQLAlchemy.
This file contains the model definitions for schema version 16,
used by Home Assistant Core 2021.6.0, which was the initial version
to include long term statistics.
It is used to test the schema migration logic.
"""
import json
import logging
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Identity,
Index,
Integer,
String,
Text,
distinct,
)
from sqlalchemy.dialects import mysql
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import Session
from homeassistant.const import (
MAX_LENGTH_EVENT_CONTEXT_ID,
MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_EVENT_ORIGIN,
MAX_LENGTH_STATE_DOMAIN,
MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSONEncoder
import homeassistant.util.dt as dt_util
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 16
_LOGGER = logging.getLogger(__name__)
DB_TIMEZONE = "+00:00"
TABLE_EVENTS = "events"
TABLE_STATES = "states"
TABLE_RECORDER_RUNS = "recorder_runs"
TABLE_SCHEMA_CHANGES = "schema_changes"
TABLE_STATISTICS = "statistics"
ALL_TABLES = [
TABLE_STATES,
TABLE_EVENTS,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
TABLE_STATISTICS,
]
DATETIME_TYPE = DateTime(timezone=True).with_variant(
mysql.DATETIME(timezone=True, fsp=6), "mysql"
)
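# Note: with_variant stores timestamps with microsecond precision (fsp=6)
# on MySQL; other dialects fall back to the generic timezone-aware DateTime.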
class Events(Base): # type: ignore
"""Event history data."""
__table_args__ = {
"mysql_default_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
__tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True)
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN))
time_fired = Column(DATETIME_TYPE, index=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
__table_args__ = (
# Used for fetching events at a specific time
# see logbook
Index("ix_events_event_type_time_fired", "event_type", "time_fired"),
)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', data='{self.event_data}', "
f"origin='{self.origin}', time_fired='{self.time_fired}'"
f")>"
)
@staticmethod
def from_event(event, event_data=None):
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=event_data or json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin.value),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id=True):
"""Convert to a natve HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class States(Base): # type: ignore
"""State change history."""
__table_args__ = {
"mysql_default_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
__tablename__ = TABLE_STATES
state_id = Column(Integer, Identity(), primary_key=True)
domain = Column(String(MAX_LENGTH_STATE_DOMAIN))
entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
state = Column(String(MAX_LENGTH_STATE_STATE))
attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
event_id = Column(
Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
)
last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
event = relationship("Events", uselist=False)
old_state = relationship("States", remote_side=[state_id])
__table_args__ = (
# Used for fetching the state of entities at a specific time
# (get_states in history.py)
Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"),
)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.States("
f"id={self.state_id}, domain='{self.domain}', entity_id='{self.entity_id}', "
f"state='{self.state}', event_id='{self.event_id}', "
f"last_updated='{self.last_updated.isoformat(sep=' ', timespec='seconds')}', "
f"old_state_id={self.old_state_id}"
f")>"
)
@staticmethod
def from_event(event):
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state = event.data.get("new_state")
dbstate = States(entity_id=entity_id)
# State got deleted
if state is None:
dbstate.state = ""
dbstate.domain = split_entity_id(entity_id)[0]
dbstate.attributes = "{}"
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.domain = state.domain
dbstate.state = state.state
dbstate.attributes = json.dumps(dict(state.attributes), cls=JSONEncoder)
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self, validate_entity_id=True):
"""Convert to an HA state object."""
try:
return State(
self.entity_id,
self.state,
json.loads(self.attributes),
process_timestamp(self.last_changed),
process_timestamp(self.last_updated),
# Join the events table on event_id to get the context instead
# as it will always be there for state_changed events
context=Context(id=None),
validate_entity_id=validate_entity_id,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class Statistics(Base): # type: ignore
"""Statistics."""
__table_args__ = {
"mysql_default_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
}
__tablename__ = TABLE_STATISTICS
id = Column(Integer, primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
source = Column(String(32))
statistic_id = Column(String(255))
start = Column(DATETIME_TYPE, index=True)
mean = Column(Float())
min = Column(Float())
max = Column(Float())
last_reset = Column(DATETIME_TYPE)
state = Column(Float())
sum = Column(Float())
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_statistic_id_start", "statistic_id", "start"),
)
@staticmethod
def from_stats(source, statistic_id, start, stats):
"""Create object from a statistics."""
return Statistics(
source=source,
statistic_id=statistic_id,
start=start,
**stats,
)
class RecorderRuns(Base): # type: ignore
"""Representation of recorder run."""
__tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), default=dt_util.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=dt_util.utcnow)
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
end = (
f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None
)
return (
f"<recorder.RecorderRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f"end={end}, closed_incorrect={self.closed_incorrect}, "
f"created='{self.created.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def entity_ids(self, point_in_time=None):
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start
)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self, validate_entity_id=True):
"""Return self, native format is this model."""
return self
class SchemaChanges(Base): # type: ignore
"""Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer)
changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.SchemaChanges("
f"id={self.change_id}, schema_version={self.schema_version}, "
f"changed='{self.changed.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def process_timestamp(ts):
"""Process a timestamp into datetime object."""
if ts is None:
return None
if ts.tzinfo is None:
return ts.replace(tzinfo=dt_util.UTC)
return dt_util.as_utc(ts)
def process_timestamp_to_utc_isoformat(ts):
"""Process a timestamp into UTC isotime."""
if ts is None:
return None
if ts.tzinfo == dt_util.UTC:
return ts.isoformat()
if ts.tzinfo is None:
return f"{ts.isoformat()}{DB_TIMEZONE}"
return ts.astimezone(dt_util.UTC).isoformat()
class LazyState(State):
"""A lazy version of core State."""
__slots__ = [
"_row",
"entity_id",
"state",
"_attributes",
"_last_changed",
"_last_updated",
"_context",
]
def __init__(self, row): # pylint: disable=super-init-not-called
"""Init the lazy state."""
self._row = row
self.entity_id = self._row.entity_id
self.state = self._row.state or ""
self._attributes = None
self._last_changed = None
self._last_updated = None
self._context = None
@property # type: ignore
def attributes(self):
"""State attributes."""
if not self._attributes:
try:
self._attributes = json.loads(self._row.attributes)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self._row)
self._attributes = {}
return self._attributes
@attributes.setter
def attributes(self, value):
"""Set attributes."""
self._attributes = value
@property # type: ignore
def context(self):
"""State context."""
if not self._context:
self._context = Context(id=None)
return self._context
@context.setter
def context(self, value):
"""Set context."""
self._context = value
@property # type: ignore
def last_changed(self):
"""Last changed datetime."""
if not self._last_changed:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value):
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore
def last_updated(self):
"""Last updated datetime."""
if not self._last_updated:
self._last_updated = process_timestamp(self._row.last_updated)
return self._last_updated
@last_updated.setter
def last_updated(self, value):
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self):
"""Return a dict representation of the LazyState.
Async friendly.
To be used for JSON serialization.
"""
if self._last_changed:
last_changed_isoformat = self._last_changed.isoformat()
else:
last_changed_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_changed
)
if self._last_updated:
last_updated_isoformat = self._last_updated.isoformat()
else:
last_updated_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_updated
)
return {
"entity_id": self.entity_id,
"state": self.state,
"attributes": self._attributes or self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
}
def __eq__(self, other):
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]
and self.entity_id == other.entity_id
and self.state == other.state
and self.attributes == other.attributes
)
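
One quirk worth noting in this schema-16 snapshot: Events, States and Statistics each assign __table_args__ twice, once as the MySQL charset dict and once as the index tuple. In a Python class body the later assignment simply replaces the earlier one, so the charset options are effectively dropped; the schema-18 file that follows merges both into a single tuple. A minimal illustration of the semantics (not part of the commit):

class Demo:
    __table_args__ = {"mysql_default_charset": "utf8mb4"}
    __table_args__ = ("ix_demo",)  # this assignment wins

assert Demo.__table_args__ == ("ix_demo",)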

View File

@@ -0,0 +1,471 @@
"""Models for SQLAlchemy.
This file contains the model definitions for schema version 18,
used by Home Assistant Core 2021.7.0, which significantly refactored
the long term statistics database models.
It is used to test the schema migration logic.
"""
import json
import logging
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Identity,
Index,
Integer,
String,
Text,
distinct,
)
from sqlalchemy.dialects import mysql
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import Session
from homeassistant.const import (
MAX_LENGTH_EVENT_CONTEXT_ID,
MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_EVENT_ORIGIN,
MAX_LENGTH_STATE_DOMAIN,
MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSONEncoder
import homeassistant.util.dt as dt_util
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 18
_LOGGER = logging.getLogger(__name__)
DB_TIMEZONE = "+00:00"
TABLE_EVENTS = "events"
TABLE_STATES = "states"
TABLE_RECORDER_RUNS = "recorder_runs"
TABLE_SCHEMA_CHANGES = "schema_changes"
TABLE_STATISTICS = "statistics"
TABLE_STATISTICS_META = "statistics_meta"
ALL_TABLES = [
TABLE_STATES,
TABLE_EVENTS,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
TABLE_STATISTICS,
TABLE_STATISTICS_META,
]
DATETIME_TYPE = DateTime(timezone=True).with_variant(
mysql.DATETIME(timezone=True, fsp=6), "mysql"
)
class Events(Base): # type: ignore
"""Event history data."""
__table_args__ = (
# Used for fetching events at a specific time
# see logbook
Index("ix_events_event_type_time_fired", "event_type", "time_fired"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True)
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN))
time_fired = Column(DATETIME_TYPE, index=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', data='{self.event_data}', "
f"origin='{self.origin}', time_fired='{self.time_fired}'"
f")>"
)
@staticmethod
def from_event(event, event_data=None):
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=event_data or json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin.value),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id=True):
"""Convert to a natve HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class States(Base): # type: ignore
"""State change history."""
__table_args__ = (
# Used for fetching the state of entities at a specific time
# (get_states in history.py)
Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATES
state_id = Column(Integer, Identity(), primary_key=True)
domain = Column(String(MAX_LENGTH_STATE_DOMAIN))
entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
state = Column(String(MAX_LENGTH_STATE_STATE))
attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
event_id = Column(
Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
)
last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
event = relationship("Events", uselist=False)
old_state = relationship("States", remote_side=[state_id])
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.States("
f"id={self.state_id}, domain='{self.domain}', entity_id='{self.entity_id}', "
f"state='{self.state}', event_id='{self.event_id}', "
f"last_updated='{self.last_updated.isoformat(sep=' ', timespec='seconds')}', "
f"old_state_id={self.old_state_id}"
f")>"
)
@staticmethod
def from_event(event):
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state = event.data.get("new_state")
dbstate = States(entity_id=entity_id)
# State got deleted
if state is None:
dbstate.state = ""
dbstate.domain = split_entity_id(entity_id)[0]
dbstate.attributes = "{}"
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.domain = state.domain
dbstate.state = state.state
dbstate.attributes = json.dumps(dict(state.attributes), cls=JSONEncoder)
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self, validate_entity_id=True):
"""Convert to an HA state object."""
try:
return State(
self.entity_id,
self.state,
json.loads(self.attributes),
process_timestamp(self.last_changed),
process_timestamp(self.last_updated),
# Join the events table on event_id to get the context instead
# as it will always be there for state_changed events
context=Context(id=None),
validate_entity_id=validate_entity_id,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class Statistics(Base): # type: ignore
"""Statistics."""
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_statistic_id_start", "metadata_id", "start"),
)
__tablename__ = TABLE_STATISTICS
id = Column(Integer, primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
metadata_id = Column(
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True)
mean = Column(Float())
min = Column(Float())
max = Column(Float())
last_reset = Column(DATETIME_TYPE)
state = Column(Float())
sum = Column(Float())
@staticmethod
def from_stats(metadata_id, start, stats):
"""Create object from a statistics."""
return Statistics(
metadata_id=metadata_id,
start=start,
**stats,
)
class StatisticsMeta(Base): # type: ignore
"""Statistics meta data."""
__tablename__ = TABLE_STATISTICS_META
id = Column(Integer, primary_key=True)
statistic_id = Column(String(255), index=True)
source = Column(String(32))
unit_of_measurement = Column(String(255))
has_mean = Column(Boolean)
has_sum = Column(Boolean)
@staticmethod
def from_meta(source, statistic_id, unit_of_measurement, has_mean, has_sum):
"""Create object from meta data."""
return StatisticsMeta(
source=source,
statistic_id=statistic_id,
unit_of_measurement=unit_of_measurement,
has_mean=has_mean,
has_sum=has_sum,
)
class RecorderRuns(Base): # type: ignore
"""Representation of recorder run."""
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
__tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), default=dt_util.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
end = (
f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None
)
return (
f"<recorder.RecorderRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f"end={end}, closed_incorrect={self.closed_incorrect}, "
f"created='{self.created.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def entity_ids(self, point_in_time=None):
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start
)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self, validate_entity_id=True):
"""Return self, native format is this model."""
return self
class SchemaChanges(Base): # type: ignore
"""Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer)
changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.SchemaChanges("
f"id={self.change_id}, schema_version={self.schema_version}, "
f"changed='{self.changed.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def process_timestamp(ts):
"""Process a timestamp into datetime object."""
if ts is None:
return None
if ts.tzinfo is None:
return ts.replace(tzinfo=dt_util.UTC)
return dt_util.as_utc(ts)
def process_timestamp_to_utc_isoformat(ts):
"""Process a timestamp into UTC isotime."""
if ts is None:
return None
if ts.tzinfo == dt_util.UTC:
return ts.isoformat()
if ts.tzinfo is None:
return f"{ts.isoformat()}{DB_TIMEZONE}"
return ts.astimezone(dt_util.UTC).isoformat()
class LazyState(State):
"""A lazy version of core State."""
__slots__ = [
"_row",
"entity_id",
"state",
"_attributes",
"_last_changed",
"_last_updated",
"_context",
]
def __init__(self, row): # pylint: disable=super-init-not-called
"""Init the lazy state."""
self._row = row
self.entity_id = self._row.entity_id
self.state = self._row.state or ""
self._attributes = None
self._last_changed = None
self._last_updated = None
self._context = None
@property # type: ignore
def attributes(self):
"""State attributes."""
if not self._attributes:
try:
self._attributes = json.loads(self._row.attributes)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self._row)
self._attributes = {}
return self._attributes
@attributes.setter
def attributes(self, value):
"""Set attributes."""
self._attributes = value
@property # type: ignore
def context(self):
"""State context."""
if not self._context:
self._context = Context(id=None)
return self._context
@context.setter
def context(self, value):
"""Set context."""
self._context = value
@property # type: ignore
def last_changed(self):
"""Last changed datetime."""
if not self._last_changed:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value):
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore
def last_updated(self):
"""Last updated datetime."""
if not self._last_updated:
self._last_updated = process_timestamp(self._row.last_updated)
return self._last_updated
@last_updated.setter
def last_updated(self, value):
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self):
"""Return a dict representation of the LazyState.
Async friendly.
To be used for JSON serialization.
"""
if self._last_changed:
last_changed_isoformat = self._last_changed.isoformat()
else:
last_changed_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_changed
)
if self._last_updated:
last_updated_isoformat = self._last_updated.isoformat()
else:
last_updated_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_updated
)
return {
"entity_id": self.entity_id,
"state": self.state,
"attributes": self._attributes or self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
}
def __eq__(self, other):
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]
and self.entity_id == other.entity_id
and self.state == other.state
and self.attributes == other.attributes
)
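
The key schema 16 to 18 change is visible in the two statistics classes: rows no longer embed source and statistic_id but reference a StatisticsMeta row through metadata_id. A hedged sketch of the new layout (identifiers and values below are hypothetical):

# Metadata is created once per statistic...
meta = StatisticsMeta.from_meta(
    source="recorder",
    statistic_id="sensor.energy_meter",
    unit_of_measurement="kWh",
    has_mean=False,
    has_sum=True,
)
# ...and, once the row is persisted and has an id, each datapoint
# references it instead of repeating the identifiers:
row = Statistics.from_stats(
    metadata_id=meta.id,
    start=dt_util.utcnow(),
    stats={"state": 5.0, "sum": 120.5},
)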

View File

@@ -0,0 +1,593 @@
"""Models for SQLAlchemy.
This file contains the model definitions for schema version 22,
used by Home Assistant Core 2021.10.0, which adds a table for
5-minute statistics.
It is used to test the schema migration logic.
"""
from __future__ import annotations
from collections.abc import Iterable
from datetime import datetime, timedelta
import json
import logging
from typing import TypedDict, overload
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Identity,
Index,
Integer,
String,
Text,
distinct,
)
from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session
from homeassistant.const import (
MAX_LENGTH_EVENT_CONTEXT_ID,
MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_EVENT_ORIGIN,
MAX_LENGTH_STATE_DOMAIN,
MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSONEncoder
import homeassistant.util.dt as dt_util
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 22
_LOGGER = logging.getLogger(__name__)
DB_TIMEZONE = "+00:00"
TABLE_EVENTS = "events"
TABLE_STATES = "states"
TABLE_RECORDER_RUNS = "recorder_runs"
TABLE_SCHEMA_CHANGES = "schema_changes"
TABLE_STATISTICS = "statistics"
TABLE_STATISTICS_META = "statistics_meta"
TABLE_STATISTICS_RUNS = "statistics_runs"
TABLE_STATISTICS_SHORT_TERM = "statistics_short_term"
ALL_TABLES = [
TABLE_STATES,
TABLE_EVENTS,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
TABLE_STATISTICS,
TABLE_STATISTICS_META,
TABLE_STATISTICS_RUNS,
TABLE_STATISTICS_SHORT_TERM,
]
DATETIME_TYPE = DateTime(timezone=True).with_variant(
mysql.DATETIME(timezone=True, fsp=6), "mysql"
)
DOUBLE_TYPE = (
Float()
.with_variant(mysql.DOUBLE(asdecimal=False), "mysql")
.with_variant(oracle.DOUBLE_PRECISION(), "oracle")
.with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
)
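# Note: DOUBLE_TYPE resolves to plain Float (REAL) on SQLite; the variants
# select the native double-precision type on MySQL, Oracle and PostgreSQL.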
class Events(Base): # type: ignore
"""Event history data."""
__table_args__ = (
# Used for fetching events at a specific time
# see logbook
Index("ix_events_event_type_time_fired", "event_type", "time_fired"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True)
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN))
time_fired = Column(DATETIME_TYPE, index=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', data='{self.event_data}', "
f"origin='{self.origin}', time_fired='{self.time_fired}'"
f")>"
)
@staticmethod
def from_event(event, event_data=None):
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=event_data
or json.dumps(event.data, cls=JSONEncoder, separators=(",", ":")),
origin=str(event.origin.value),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id=True):
"""Convert to a native HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class States(Base): # type: ignore
"""State change history."""
__table_args__ = (
# Used for fetching the state of entities at a specific time
# (get_states in history.py)
Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATES
state_id = Column(Integer, Identity(), primary_key=True)
domain = Column(String(MAX_LENGTH_STATE_DOMAIN))
entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
state = Column(String(MAX_LENGTH_STATE_STATE))
attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
event_id = Column(
Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
)
last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
event = relationship("Events", uselist=False)
old_state = relationship("States", remote_side=[state_id])
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.States("
f"id={self.state_id}, domain='{self.domain}', entity_id='{self.entity_id}', "
f"state='{self.state}', event_id='{self.event_id}', "
f"last_updated='{self.last_updated.isoformat(sep=' ', timespec='seconds')}', "
f"old_state_id={self.old_state_id}"
f")>"
)
@staticmethod
def from_event(event):
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state = event.data.get("new_state")
dbstate = States(entity_id=entity_id)
# State got deleted
if state is None:
dbstate.state = ""
dbstate.domain = split_entity_id(entity_id)[0]
dbstate.attributes = "{}"
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.domain = state.domain
dbstate.state = state.state
dbstate.attributes = json.dumps(
dict(state.attributes), cls=JSONEncoder, separators=(",", ":")
)
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self, validate_entity_id=True):
"""Convert to an HA state object."""
try:
return State(
self.entity_id,
self.state,
json.loads(self.attributes),
process_timestamp(self.last_changed),
process_timestamp(self.last_updated),
# Join the events table on event_id to get the context instead
# as it will always be there for state_changed events
context=Context(id=None),
validate_entity_id=validate_entity_id,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class StatisticResult(TypedDict):
"""Statistic result data class.
Allows multiple datapoints for the same statistic_id.
"""
meta: StatisticMetaData
stat: Iterable[StatisticData]
class StatisticDataBase(TypedDict):
"""Mandatory fields for statistic data class."""
start: datetime
class StatisticData(StatisticDataBase, total=False):
"""Statistic data class."""
mean: float
min: float
max: float
last_reset: datetime | None
state: float
sum: float
class StatisticsBase:
"""Statistics base class."""
id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
@declared_attr
def metadata_id(self):
"""Define the metadata_id column for sub classes."""
return Column(
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE)
max = Column(DOUBLE_TYPE)
last_reset = Column(DATETIME_TYPE)
state = Column(DOUBLE_TYPE)
sum = Column(DOUBLE_TYPE)
@classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData):
"""Create object from a statistics."""
return cls( # type: ignore
metadata_id=metadata_id,
**stats,
)
class Statistics(Base, StatisticsBase): # type: ignore
"""Long term statistics."""
duration = timedelta(hours=1)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_statistic_id_start", "metadata_id", "start"),
)
__tablename__ = TABLE_STATISTICS
class StatisticsShortTerm(Base, StatisticsBase): # type: ignore
"""Short term statistics."""
duration = timedelta(minutes=5)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_short_term_statistic_id_start", "metadata_id", "start"),
)
__tablename__ = TABLE_STATISTICS_SHORT_TERM
class StatisticMetaData(TypedDict):
"""Statistic meta data class."""
statistic_id: str
unit_of_measurement: str | None
has_mean: bool
has_sum: bool
class StatisticsMeta(Base): # type: ignore
"""Statistics meta data."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATISTICS_META
id = Column(Integer, Identity(), primary_key=True)
statistic_id = Column(String(255), index=True)
source = Column(String(32))
unit_of_measurement = Column(String(255))
has_mean = Column(Boolean)
has_sum = Column(Boolean)
@staticmethod
def from_meta(
source: str,
statistic_id: str,
unit_of_measurement: str | None,
has_mean: bool,
has_sum: bool,
) -> StatisticsMeta:
"""Create object from meta data."""
return StatisticsMeta(
source=source,
statistic_id=statistic_id,
unit_of_measurement=unit_of_measurement,
has_mean=has_mean,
has_sum=has_sum,
)
class RecorderRuns(Base): # type: ignore
"""Representation of recorder run."""
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
__tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), default=dt_util.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
end = (
f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None
)
return (
f"<recorder.RecorderRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f"end={end}, closed_incorrect={self.closed_incorrect}, "
f"created='{self.created.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def entity_ids(self, point_in_time=None):
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start
)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self, validate_entity_id=True):
"""Return self, native format is this model."""
return self
class SchemaChanges(Base): # type: ignore
"""Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer)
changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.SchemaChanges("
f"id={self.change_id}, schema_version={self.schema_version}, "
f"changed='{self.changed.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
class StatisticsRuns(Base): # type: ignore
"""Representation of statistics run."""
__tablename__ = TABLE_STATISTICS_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StatisticsRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f")>"
)
@overload
def process_timestamp(ts: None) -> None:
...
@overload
def process_timestamp(ts: datetime) -> datetime:
...
def process_timestamp(ts: datetime | None) -> datetime | None:
"""Process a timestamp into datetime object."""
if ts is None:
return None
if ts.tzinfo is None:
return ts.replace(tzinfo=dt_util.UTC)
return dt_util.as_utc(ts)
@overload
def process_timestamp_to_utc_isoformat(ts: None) -> None:
...
@overload
def process_timestamp_to_utc_isoformat(ts: datetime) -> str:
...
def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
"""Process a timestamp into UTC isotime."""
if ts is None:
return None
if ts.tzinfo == dt_util.UTC:
return ts.isoformat()
if ts.tzinfo is None:
return f"{ts.isoformat()}{DB_TIMEZONE}"
return ts.astimezone(dt_util.UTC).isoformat()
class LazyState(State):
"""A lazy version of core State."""
__slots__ = [
"_row",
"entity_id",
"state",
"_attributes",
"_last_changed",
"_last_updated",
"_context",
]
def __init__(self, row): # pylint: disable=super-init-not-called
"""Init the lazy state."""
self._row = row
self.entity_id = self._row.entity_id
self.state = self._row.state or ""
self._attributes = None
self._last_changed = None
self._last_updated = None
self._context = None
@property # type: ignore
def attributes(self):
"""State attributes."""
if not self._attributes:
try:
self._attributes = json.loads(self._row.attributes)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self._row)
self._attributes = {}
return self._attributes
@attributes.setter
def attributes(self, value):
"""Set attributes."""
self._attributes = value
@property # type: ignore
def context(self):
"""State context."""
if not self._context:
self._context = Context(id=None)
return self._context
@context.setter
def context(self, value):
"""Set context."""
self._context = value
@property # type: ignore
def last_changed(self):
"""Last changed datetime."""
if not self._last_changed:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value):
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore
def last_updated(self):
"""Last updated datetime."""
if not self._last_updated:
self._last_updated = process_timestamp(self._row.last_updated)
return self._last_updated
@last_updated.setter
def last_updated(self, value):
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self):
"""Return a dict representation of the LazyState.
Async friendly.
To be used for JSON serialization.
"""
if self._last_changed:
last_changed_isoformat = self._last_changed.isoformat()
else:
last_changed_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_changed
)
if self._last_updated:
last_updated_isoformat = self._last_updated.isoformat()
else:
last_updated_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_updated
)
return {
"entity_id": self.entity_id,
"state": self.state,
"attributes": self._attributes or self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
}
def __eq__(self, other):
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]
and self.entity_id == other.entity_id
and self.state == other.state
and self.attributes == other.attributes
)
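
Schema 22 shares one column set between the hourly and 5-minute tables through the StatisticsBase mixin, with declared_attr deferring the metadata_id foreign key to each subclass. Since StatisticData is a TypedDict, a row for either table can be built from a plain dict; a short sketch with hypothetical values:

stats: StatisticData = {
    "start": dt_util.utcnow(),
    "state": 5.0,
    "sum": 120.5,
}
# The same payload maps onto either table:
hourly = Statistics.from_stats(metadata_id=1, stats=stats)
five_minute = StatisticsShortTerm.from_stats(metadata_id=1, stats=stats)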

View File

@@ -1,7 +1,10 @@
 """The tests for the Recorder component."""
 # pylint: disable=protected-access
 import datetime
+import importlib
 import sqlite3
+import sys
+import threading
 from unittest.mock import ANY, Mock, PropertyMock, call, patch

 import pytest
@@ -222,7 +225,8 @@ async def test_events_during_migration_queue_exhausted(hass):
     assert len(db_states) == 2


-async def test_schema_migrate(hass):
+@pytest.mark.parametrize("start_version", [0, 16, 18, 22])
+async def test_schema_migrate(hass, start_version):
     """Test the full schema migration logic.

     We're just testing that the logic can execute successfully here without
@@ -230,21 +234,76 @@ async def test_schema_migrate(hass):
     inspection could quickly become quite cumbersome.
     """

+    migration_done = threading.Event()
+    migration_stall = threading.Event()
+    migration_version = None
+    real_migration = recorder.migration.migrate_schema
+
+    def _create_engine_test(*args, **kwargs):
+        """Test version of create_engine that initializes with old schema.
+
+        This simulates an existing db with the old schema.
+        """
+        module = f"tests.components.recorder.models_schema_{str(start_version)}"
+        importlib.import_module(module)
+        old_models = sys.modules[module]
+        engine = create_engine(*args, **kwargs)
+        old_models.Base.metadata.create_all(engine)
+        if start_version > 0:
+            with Session(engine) as session:
+                session.add(recorder.models.SchemaChanges(schema_version=start_version))
+                session.commit()
+        return engine
+
     def _mock_setup_run(self):
         self.run_info = RecorderRuns(
             start=self.recording_start, created=dt_util.utcnow()
         )

-    with patch("sqlalchemy.create_engine", new=create_engine_test), patch(
+    def _instrument_migration(*args):
+        """Control migration progress and check results."""
+        nonlocal migration_done
+        nonlocal migration_version
+        nonlocal migration_stall
+        migration_stall.wait()
+        try:
+            real_migration(*args)
+        except Exception:
+            migration_done.set()
+            raise
+
+        # Check and report the outcome of the migration; if migration fails
+        # the recorder will silently create a new database.
+        with session_scope(hass=hass) as session:
+            res = (
+                session.query(models.SchemaChanges)
+                .order_by(models.SchemaChanges.change_id.desc())
+                .first()
+            )
+            migration_version = res.schema_version
+        migration_done.set()
+
+    with patch(
+        "homeassistant.components.recorder.create_engine", new=_create_engine_test
+    ), patch(
         "homeassistant.components.recorder.Recorder._setup_run",
         side_effect=_mock_setup_run,
         autospec=True,
-    ) as setup_run:
+    ) as setup_run, patch(
+        "homeassistant.components.recorder.migration.migrate_schema",
+        wraps=_instrument_migration,
+    ):
         await async_setup_component(
             hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
         )
+        assert await recorder.async_migration_in_progress(hass) is True
+        migration_stall.set()
         await hass.async_block_till_done()
+        migration_done.wait()
+        await async_wait_recording_done_without_instance(hass)
+        assert migration_version == models.SCHEMA_VERSION
         assert setup_run.called
+        assert await recorder.async_migration_in_progress(hass) is not True


 def test_invalid_update():
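
The migration_stall / migration_done pair above implements a plain threading handshake: the test holds the wrapped migration until it has asserted that a migration is reported as in progress, then releases it and waits for the recorded result. A standalone sketch of the pattern (not part of the commit):

import threading

stall = threading.Event()
done = threading.Event()

def migrate():
    stall.wait()  # block until the main thread has done its checks
    # ... the wrapped migration would run here ...
    done.set()    # report completion

worker = threading.Thread(target=migrate)
worker.start()
# main thread: migration is provably still pending here
stall.set()  # release the migration
done.wait()  # wait for it to finish
worker.join()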