Improve reliability of context id migration (#89609)

* Split context id migration into states and events tasks

Since events can finish much earlier than states, we
would keep looking at the table because states was not
done. Make them separate tasks.

* add retry decorator

* fix migration happening twice

* another case
This commit is contained in:
J. Nick Koston 2023-03-12 15:41:48 -10:00 committed by GitHub
parent 85ca94e9d4
commit b9ac6b4a7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 258 additions and 47 deletions

View File

@ -97,9 +97,9 @@ from .tasks import (
ChangeStatisticsUnitTask,
ClearStatisticsTask,
CommitTask,
ContextIDMigrationTask,
DatabaseLockTask,
EntityIDMigrationTask,
EventsContextIDMigrationTask,
EventTask,
EventTypeIDMigrationTask,
ImportStatisticsTask,
@ -107,6 +107,7 @@ from .tasks import (
PerodicCleanupTask,
PurgeTask,
RecorderTask,
StatesContextIDMigrationTask,
StatisticsTask,
StopTask,
SynchronizeTask,
@ -654,8 +655,9 @@ class Recorder(threading.Thread):
self.migration_is_live = migration.live_migration(schema_status)
self.hass.add_job(self.async_connection_success)
database_was_ready = self.migration_is_live or schema_status.valid
if self.migration_is_live or schema_status.valid:
if database_was_ready:
# If the migrate is live or the schema is valid, we need to
# wait for startup to complete. If its not live, we need to continue
# on.
@ -670,7 +672,6 @@ class Recorder(threading.Thread):
# Make sure we cleanly close the run if
# we restart before startup finishes
self._shutdown()
self._activate_and_set_db_ready()
return
if not schema_status.valid:
@ -692,7 +693,8 @@ class Recorder(threading.Thread):
self._shutdown()
return
self._activate_and_set_db_ready()
if not database_was_ready:
self._activate_and_set_db_ready()
# Catch up with missed statistics
with session_scope(session=self.get_session()) as session:
@ -710,9 +712,14 @@ class Recorder(threading.Thread):
if (
self.schema_version < 36
or session.execute(has_events_context_ids_to_migrate()).scalar()
):
self.queue_task(StatesContextIDMigrationTask())
if (
self.schema_version < 36
or session.execute(has_states_context_ids_to_migrate()).scalar()
):
self.queue_task(ContextIDMigrationTask())
self.queue_task(EventsContextIDMigrationTask())
if (
self.schema_version < 37
@ -1236,9 +1243,13 @@ class Recorder(threading.Thread):
"""Run post schema migration tasks."""
migration.post_schema_migration(self, old_version, new_version)
def _migrate_context_ids(self) -> bool:
"""Migrate context ids if needed."""
return migration.migrate_context_ids(self)
def _migrate_states_context_ids(self) -> bool:
"""Migrate states context ids if needed."""
return migration.migrate_states_context_ids(self)
def _migrate_events_context_ids(self) -> bool:
"""Migrate events context ids if needed."""
return migration.migrate_events_context_ids(self)
def _migrate_event_type_ids(self) -> bool:
"""Migrate event type ids if needed."""

View File

@ -64,7 +64,7 @@ from .tasks import (
PostSchemaMigrationTask,
StatisticsTimestampMigrationCleanupTask,
)
from .util import database_job_retry_wrapper, session_scope
from .util import database_job_retry_wrapper, retryable_database_job, session_scope
if TYPE_CHECKING:
from . import Recorder
@ -1301,8 +1301,43 @@ def _context_id_to_bytes(context_id: str | None) -> bytes | None:
return None
def migrate_context_ids(instance: Recorder) -> bool:
"""Migrate context_ids to use binary format."""
@retryable_database_job("migrate states context_ids to binary format")
def migrate_states_context_ids(instance: Recorder) -> bool:
    """Migrate states context_ids to use binary format.

    Processes one batch of rows per invocation. Returns True once no
    string context ids remain; returning False signals the caller
    (the recorder task queue) to schedule another run.
    """
    _to_bytes = _context_id_to_bytes
    session_maker = instance.get_session
    _LOGGER.debug("Migrating states context_ids to binary format")
    with session_scope(session=session_maker()) as session:
        if states := session.execute(find_states_context_ids_to_migrate()).all():
            # Bulk UPDATE by primary key: each mapping carries the
            # state_id plus the new binary columns, and clears the old
            # string columns to NULL in the same statement.
            session.execute(
                update(States),
                [
                    {
                        "state_id": state_id,
                        "context_id": None,
                        # Unparseable/legacy ids fall back to the all-zero id.
                        "context_id_bin": _to_bytes(context_id) or _EMPTY_CONTEXT_ID,
                        "context_user_id": None,
                        "context_user_id_bin": _to_bytes(context_user_id),
                        "context_parent_id": None,
                        "context_parent_id_bin": _to_bytes(context_parent_id),
                    }
                    for state_id, context_id, context_user_id, context_parent_id in states
                ],
            )
        # If there is more work to do return False
        # so that we can be called again
        is_done = not states

    # NOTE(review): reconstructed from an indentation-stripped diff — the
    # placement of this block relative to the session scope is assumed;
    # confirm against upstream. _drop_index opens its own session via
    # session_maker, so it is kept outside the scope above.
    if is_done:
        # Once migrated, the string context_id column is no longer
        # queried, so its index can be dropped.
        _drop_index(session_maker, "states", "ix_states_context_id")

    _LOGGER.debug("Migrating states context_ids to binary format: done=%s", is_done)
    return is_done
@retryable_database_job("migrate events context_ids to binary format")
def migrate_events_context_ids(instance: Recorder) -> bool:
"""Migrate events context_ids to use binary format."""
_to_bytes = _context_id_to_bytes
session_maker = instance.get_session
_LOGGER.debug("Migrating context_ids to binary format")
@ -1323,34 +1358,18 @@ def migrate_context_ids(instance: Recorder) -> bool:
for event_id, context_id, context_user_id, context_parent_id in events
],
)
if states := session.execute(find_states_context_ids_to_migrate()).all():
session.execute(
update(States),
[
{
"state_id": state_id,
"context_id": None,
"context_id_bin": _to_bytes(context_id) or _EMPTY_CONTEXT_ID,
"context_user_id": None,
"context_user_id_bin": _to_bytes(context_user_id),
"context_parent_id": None,
"context_parent_id_bin": _to_bytes(context_parent_id),
}
for state_id, context_id, context_user_id, context_parent_id in states
],
)
# If there is more work to do return False
# so that we can be called again
is_done = not (events or states)
is_done = not events
if is_done:
_drop_index(session_maker, "events", "ix_events_context_id")
_drop_index(session_maker, "states", "ix_states_context_id")
_LOGGER.debug("Migrating context_ids to binary format: done=%s", is_done)
_LOGGER.debug("Migrating events context_ids to binary format: done=%s", is_done)
return is_done
@retryable_database_job("migrate events event_types to event_type_ids")
def migrate_event_type_ids(instance: Recorder) -> bool:
"""Migrate event_type to event_type_ids."""
session_maker = instance.get_session
@ -1407,6 +1426,7 @@ def migrate_event_type_ids(instance: Recorder) -> bool:
return is_done
@retryable_database_job("migrate states entity_ids to states_meta")
def migrate_entity_ids(instance: Recorder) -> bool:
"""Migrate entity_ids to states_meta.
@ -1468,6 +1488,7 @@ def migrate_entity_ids(instance: Recorder) -> bool:
return is_done
@retryable_database_job("post migrate states entity_ids to states_meta")
def post_migrate_entity_ids(instance: Recorder) -> bool:
"""Remove old entity_id strings from states.

View File

@ -346,16 +346,33 @@ class AdjustLRUSizeTask(RecorderTask):
@dataclass
class ContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate context ids."""
class StatesContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate states context ids."""
commit_before = False
def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if not instance._migrate_context_ids(): # pylint: disable=[protected-access]
if (
not instance._migrate_states_context_ids() # pylint: disable=[protected-access]
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(ContextIDMigrationTask())
instance.queue_task(StatesContextIDMigrationTask())
@dataclass
class EventsContextIDMigrationTask(RecorderTask):
    """An object to insert into the recorder queue to migrate events context ids."""

    # No commit is needed before this task runs.
    commit_before = False

    def run(self, instance: Recorder) -> None:
        """Run context id migration task."""
        finished = (
            instance._migrate_events_context_ids()  # pylint: disable=[protected-access]
        )
        if not finished:
            # Schedule a new migration task if this one didn't finish
            instance.queue_task(EventsContextIDMigrationTask())
@dataclass

View File

@ -32,10 +32,11 @@ from homeassistant.components.recorder.db_schema import (
)
from homeassistant.components.recorder.queries import select_event_type_ids
from homeassistant.components.recorder.tasks import (
ContextIDMigrationTask,
EntityIDMigrationTask,
EntityIDPostMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask,
StatesContextIDMigrationTask,
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
@ -558,7 +559,7 @@ def test_raise_if_exception_missing_empty_cause_str() -> None:
@pytest.mark.parametrize("enable_migrate_context_ids", [True])
async def test_migrate_context_ids(
async def test_migrate_events_context_ids(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
"""Test we can migrate old uuid context ids and ulid context ids to binary format."""
@ -632,7 +633,7 @@ async def test_migrate_context_ids(
await async_wait_recording_done(hass)
# This is a threadsafe way to add a task to the recorder
instance.queue_task(ContextIDMigrationTask())
instance.queue_task(EventsContextIDMigrationTask())
await async_recorder_block_till_done(hass)
def _object_as_dict(obj):
@ -701,6 +702,137 @@ async def test_migrate_context_ids(
assert invalid_context_id_event["context_parent_id_bin"] is None
@pytest.mark.parametrize("enable_migrate_context_ids", [True])
async def test_migrate_states_context_ids(
    async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant
) -> None:
    """Test we can migrate old uuid context ids and ulid context ids to binary format."""
    instance = await async_setup_recorder_instance(hass)
    await async_wait_recording_done(hass)

    test_uuid = uuid.uuid4()
    uuid_hex = test_uuid.hex
    uuid_bin = test_uuid.bytes

    # NOTE(review): despite the name, this helper inserts States rows
    # (copied from the events variant of this test) — consider renaming.
    def _insert_events():
        with session_scope(hass=hass) as session:
            session.add_all(
                (
                    # Legacy 32-char hex UUID context id.
                    States(
                        entity_id="state.old_uuid_context_id",
                        last_updated_ts=1677721632.452529,
                        context_id=uuid_hex,
                        context_id_bin=None,
                        context_user_id=None,
                        context_user_id_bin=None,
                        context_parent_id=None,
                        context_parent_id_bin=None,
                    ),
                    # No context id at all.
                    States(
                        entity_id="state.empty_context_id",
                        last_updated_ts=1677721632.552529,
                        context_id=None,
                        context_id_bin=None,
                        context_user_id=None,
                        context_user_id_bin=None,
                        context_parent_id=None,
                        context_parent_id_bin=None,
                    ),
                    # Modern ULID context id with user and parent ids.
                    States(
                        entity_id="state.ulid_context_id",
                        last_updated_ts=1677721632.552529,
                        context_id="01ARZ3NDEKTSV4RRFFQ69G5FAV",
                        context_id_bin=None,
                        context_user_id="9400facee45711eaa9308bfd3d19e474",
                        context_user_id_bin=None,
                        context_parent_id="01ARZ3NDEKTSV4RRFFQ69G5FA2",
                        context_parent_id_bin=None,
                    ),
                    # Garbage that is neither a UUID nor a ULID.
                    States(
                        entity_id="state.invalid_context_id",
                        last_updated_ts=1677721632.552529,
                        context_id="invalid",
                        context_id_bin=None,
                        context_user_id=None,
                        context_user_id_bin=None,
                        context_parent_id=None,
                        context_parent_id_bin=None,
                    ),
                )
            )

    await instance.async_add_executor_job(_insert_events)
    await async_wait_recording_done(hass)
    # This is a threadsafe way to add a task to the recorder
    instance.queue_task(StatesContextIDMigrationTask())
    await async_recorder_block_till_done(hass)

    def _object_as_dict(obj):
        # Flatten a mapped row into {column_key: value} for easy asserts.
        return {c.key: getattr(obj, c.key) for c in inspect(obj).mapper.column_attrs}

    def _fetch_migrated_states():
        with session_scope(hass=hass) as session:
            # NOTE(review): variable holds States rows, not events — the
            # name is a leftover from the events variant of this test.
            events = (
                session.query(States)
                .filter(
                    States.entity_id.in_(
                        [
                            "state.old_uuid_context_id",
                            "state.empty_context_id",
                            "state.ulid_context_id",
                            "state.invalid_context_id",
                        ]
                    )
                )
                .all()
            )
            assert len(events) == 4
            return {state.entity_id: _object_as_dict(state) for state in events}

    states_by_entity_id = await instance.async_add_executor_job(_fetch_migrated_states)

    # UUID row: string columns cleared, binary id equals the raw bytes.
    old_uuid_context_id = states_by_entity_id["state.old_uuid_context_id"]
    assert old_uuid_context_id["context_id"] is None
    assert old_uuid_context_id["context_user_id"] is None
    assert old_uuid_context_id["context_parent_id"] is None
    assert old_uuid_context_id["context_id_bin"] == uuid_bin
    assert old_uuid_context_id["context_user_id_bin"] is None
    assert old_uuid_context_id["context_parent_id_bin"] is None

    # Missing id row: migrated to the all-zero sentinel id.
    empty_context_id = states_by_entity_id["state.empty_context_id"]
    assert empty_context_id["context_id"] is None
    assert empty_context_id["context_user_id"] is None
    assert empty_context_id["context_parent_id"] is None
    assert empty_context_id["context_id_bin"] == b"\x00" * 16
    assert empty_context_id["context_user_id_bin"] is None
    assert empty_context_id["context_parent_id_bin"] is None

    # ULID row: ids round-trip through bytes_to_ulid.
    ulid_context_id = states_by_entity_id["state.ulid_context_id"]
    assert ulid_context_id["context_id"] is None
    assert ulid_context_id["context_user_id"] is None
    assert ulid_context_id["context_parent_id"] is None
    assert (
        bytes_to_ulid(ulid_context_id["context_id_bin"]) == "01ARZ3NDEKTSV4RRFFQ69G5FAV"
    )
    assert (
        ulid_context_id["context_user_id_bin"]
        == b"\x94\x00\xfa\xce\xe4W\x11\xea\xa90\x8b\xfd=\x19\xe4t"
    )
    assert (
        bytes_to_ulid(ulid_context_id["context_parent_id_bin"])
        == "01ARZ3NDEKTSV4RRFFQ69G5FA2"
    )

    # Invalid row: also collapses to the all-zero sentinel id.
    invalid_context_id = states_by_entity_id["state.invalid_context_id"]
    assert invalid_context_id["context_id"] is None
    assert invalid_context_id["context_user_id"] is None
    assert invalid_context_id["context_parent_id"] is None
    assert invalid_context_id["context_id_bin"] == b"\x00" * 16
    assert invalid_context_id["context_user_id_bin"] is None
    assert invalid_context_id["context_parent_id_bin"] is None
@pytest.mark.parametrize("enable_migrate_event_type_ids", [True])
async def test_migrate_event_type_ids(
async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant

View File

@ -86,6 +86,7 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
EventOrigin.local,
time_fired=now,
)
number_of_migrations = 5
with patch.object(recorder, "db_schema", old_db_schema), patch.object(
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
@ -100,11 +101,15 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
), patch(
CREATE_ENGINE_TARGET, new=_create_engine_test
), patch(
"homeassistant.components.recorder.Recorder._migrate_context_ids",
"homeassistant.components.recorder.Recorder._migrate_events_context_ids",
), patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
), patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids",
), patch(
"homeassistant.components.recorder.Recorder._migrate_entity_ids",
), patch(
"homeassistant.components.recorder.Recorder._post_migrate_entity_ids"
):
hass = await async_test_home_assistant(asyncio.get_running_loop())
recorder_helper.async_initialize_recorder(hass)
@ -122,8 +127,10 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
await recorder.get_instance(hass).async_add_executor_job(_add_data)
await hass.async_block_till_done()
await recorder.get_instance(hass).async_block_till_done()
await hass.async_stop()
await hass.async_block_till_done()
dt_util.DEFAULT_TIME_ZONE = ORIG_TZ
@ -137,7 +144,8 @@ async def test_migrate_times(caplog: pytest.LogCaptureFixture, tmpdir) -> None:
# We need to wait for all the migration tasks to complete
# before we can check the database.
for _ in range(5):
for _ in range(number_of_migrations):
await recorder.get_instance(hass).async_block_till_done()
await async_wait_recording_done(hass)
def _get_test_data_from_db():

View File

@ -1250,8 +1250,15 @@ def hass_recorder(
if enable_statistics_table_validation
else itertools.repeat(set())
)
migrate_context_ids = (
recorder.Recorder._migrate_context_ids if enable_migrate_context_ids else None
migrate_states_context_ids = (
recorder.Recorder._migrate_states_context_ids
if enable_migrate_context_ids
else None
)
migrate_events_context_ids = (
recorder.Recorder._migrate_events_context_ids
if enable_migrate_context_ids
else None
)
migrate_event_type_ids = (
recorder.Recorder._migrate_event_type_ids
@ -1274,8 +1281,12 @@ def hass_recorder(
side_effect=stats_validate,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_context_ids",
side_effect=migrate_context_ids,
"homeassistant.components.recorder.Recorder._migrate_events_context_ids",
side_effect=migrate_events_context_ids,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
side_effect=migrate_states_context_ids,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids",
@ -1354,8 +1365,15 @@ async def async_setup_recorder_instance(
if enable_statistics_table_validation
else itertools.repeat(set())
)
migrate_context_ids = (
recorder.Recorder._migrate_context_ids if enable_migrate_context_ids else None
migrate_states_context_ids = (
recorder.Recorder._migrate_states_context_ids
if enable_migrate_context_ids
else None
)
migrate_events_context_ids = (
recorder.Recorder._migrate_events_context_ids
if enable_migrate_context_ids
else None
)
migrate_event_type_ids = (
recorder.Recorder._migrate_event_type_ids
@ -1378,8 +1396,12 @@ async def async_setup_recorder_instance(
side_effect=stats_validate,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_context_ids",
side_effect=migrate_context_ids,
"homeassistant.components.recorder.Recorder._migrate_events_context_ids",
side_effect=migrate_events_context_ids,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_states_context_ids",
side_effect=migrate_states_context_ids,
autospec=True,
), patch(
"homeassistant.components.recorder.Recorder._migrate_event_type_ids",