Don't use shared session during recorder migration (#65672)

Erik Montnemery 2022-02-04 18:55:11 +01:00 committed by Paulus Schoutsen
parent 4e3cd1471a
commit 9cd6bb7335
2 changed files with 176 additions and 143 deletions
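The pattern adopted throughout: instead of one session shared across the whole migration, each schema operation now opens its own short-lived session via `session_scope(session=instance.get_session())`, so a failed statement can no longer leave a shared session in a broken state for the steps that follow. Below is a minimal sketch of that pattern; the `session_scope` body is an illustrative approximation, not the actual recorder source.

from contextlib import contextmanager

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session


@contextmanager
def session_scope(*, session):
    """Commit on success, roll back on failure, always close the session.

    Illustrative approximation of the recorder's session_scope helper.
    """
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


engine = create_engine("sqlite://")

# Each migration step gets a fresh session; an error here rolls back and
# closes only this session instead of poisoning a long-lived shared one.
with session_scope(session=Session(engine)) as session:
    session.connection().execute(text("CREATE TABLE hello (id int)"))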

homeassistant/components/recorder/migration.py

@@ -68,20 +68,18 @@ def schema_is_current(current_version):
 def migrate_schema(instance, current_version):
     """Check if the schema needs to be upgraded."""
-    with session_scope(session=instance.get_session()) as session:
-        _LOGGER.warning(
-            "Database is about to upgrade. Schema version: %s", current_version
-        )
-        for version in range(current_version, SCHEMA_VERSION):
-            new_version = version + 1
-            _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
-            _apply_update(instance, session, new_version, current_version)
-            session.add(SchemaChanges(schema_version=new_version))
-            _LOGGER.info("Upgrade to version %s done", new_version)
+    _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
+    for version in range(current_version, SCHEMA_VERSION):
+        new_version = version + 1
+        _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
+        _apply_update(instance, new_version, current_version)
+        with session_scope(session=instance.get_session()) as session:
+            session.add(SchemaChanges(schema_version=new_version))
+        _LOGGER.info("Upgrade to version %s done", new_version)


-def _create_index(connection, table_name, index_name):
+def _create_index(instance, table_name, index_name):
     """Create an index for the specified table.

     The index name should match the name given for the index
@@ -103,6 +101,8 @@ def _create_index(connection, table_name, index_name):
         index_name,
     )
     try:
-        index.create(connection)
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            index.create(connection)
     except (InternalError, ProgrammingError, OperationalError) as err:
         raise_if_exception_missing_str(err, ["already exists", "duplicate"])
@@ -113,7 +113,7 @@ def _create_index(connection, table_name, index_name):
     _LOGGER.debug("Finished creating %s", index_name)


-def _drop_index(connection, table_name, index_name):
+def _drop_index(instance, table_name, index_name):
     """Drop an index from a specified table.

     There is no universal way to do something like `DROP INDEX IF EXISTS`
@@ -129,6 +129,8 @@ def _drop_index(connection, table_name, index_name):
     # Engines like DB2/Oracle
    try:
-        connection.execute(text(f"DROP INDEX {index_name}"))
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(text(f"DROP INDEX {index_name}"))
     except SQLAlchemyError:
         pass
@@ -138,6 +140,8 @@ def _drop_index(connection, table_name, index_name):
     # Engines like SQLite, SQL Server
     if not success:
         try:
-            connection.execute(
-                text(
-                    "DROP INDEX {table}.{index}".format(
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "DROP INDEX {table}.{index}".format(
@@ -153,6 +157,8 @@ def _drop_index(connection, table_name, index_name):
     if not success:
         # Engines like MySQL, MS Access
         try:
-            connection.execute(
-                text(
-                    "DROP INDEX {index} ON {table}".format(
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "DROP INDEX {index} ON {table}".format(
@@ -184,7 +190,7 @@ def _drop_index(connection, table_name, index_name):
     )


-def _add_columns(connection, table_name, columns_def):
+def _add_columns(instance, table_name, columns_def):
     """Add columns to a table."""
     _LOGGER.warning(
         "Adding columns %s to table %s. Note: this can take several "
@@ -197,6 +203,8 @@ def _add_columns(connection, table_name, columns_def):
     columns_def = [f"ADD {col_def}" for col_def in columns_def]

     try:
-        connection.execute(
-            text(
-                "ALTER TABLE {table} {columns_def}".format(
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(
+                text(
+                    "ALTER TABLE {table} {columns_def}".format(
@@ -212,6 +220,8 @@ def _add_columns(connection, table_name, columns_def):
     for column_def in columns_def:
         try:
-            connection.execute(
-                text(
-                    "ALTER TABLE {table} {column_def}".format(
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "ALTER TABLE {table} {column_def}".format(
@@ -228,7 +238,7 @@ def _add_columns(connection, table_name, columns_def):
     )


-def _modify_columns(connection, engine, table_name, columns_def):
+def _modify_columns(instance, engine, table_name, columns_def):
     """Modify columns in a table."""
     if engine.dialect.name == "sqlite":
         _LOGGER.debug(
@@ -261,6 +271,8 @@ def _modify_columns(connection, engine, table_name, columns_def):
     columns_def = [f"MODIFY {col_def}" for col_def in columns_def]

     try:
-        connection.execute(
-            text(
-                "ALTER TABLE {table} {columns_def}".format(
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(
+                text(
+                    "ALTER TABLE {table} {columns_def}".format(
@@ -274,6 +286,8 @@ def _modify_columns(connection, engine, table_name, columns_def):
     for column_def in columns_def:
         try:
-            connection.execute(
-                text(
-                    "ALTER TABLE {table} {column_def}".format(
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "ALTER TABLE {table} {column_def}".format(
@@ -287,7 +301,7 @@ def _modify_columns(connection, engine, table_name, columns_def):
     )


-def _update_states_table_with_foreign_key_options(connection, engine):
+def _update_states_table_with_foreign_key_options(instance, engine):
     """Add the options to foreign key constraints."""
     inspector = sqlalchemy.inspect(engine)
     alters = []
@@ -316,6 +330,8 @@ def _update_states_table_with_foreign_key_options(connection, engine):
     for alter in alters:
         try:
-            connection.execute(DropConstraint(alter["old_fk"]))
-            for fkc in states_key_constraints:
-                if fkc.column_keys == alter["columns"]:
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(DropConstraint(alter["old_fk"]))
+                for fkc in states_key_constraints:
+                    if fkc.column_keys == alter["columns"]:
@@ -326,7 +342,7 @@ def _update_states_table_with_foreign_key_options(connection, engine):
     )


-def _drop_foreign_key_constraints(connection, engine, table, columns):
+def _drop_foreign_key_constraints(instance, engine, table, columns):
     """Drop foreign key constraints for a table on specific columns."""
     inspector = sqlalchemy.inspect(engine)
     drops = []
@@ -345,6 +361,8 @@ def _drop_foreign_key_constraints(connection, engine, table, columns):
     for drop in drops:
         try:
-            connection.execute(DropConstraint(drop))
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(DropConstraint(drop))
         except (InternalError, OperationalError):
             _LOGGER.exception(
@@ -354,17 +372,16 @@ def _drop_foreign_key_constraints(connection, engine, table, columns):
     )


-def _apply_update(instance, session, new_version, old_version):  # noqa: C901
+def _apply_update(instance, new_version, old_version):  # noqa: C901
     """Perform operations to bring schema up to date."""
     engine = instance.engine
-    connection = session.connection()
     if new_version == 1:
-        _create_index(connection, "events", "ix_events_time_fired")
+        _create_index(instance, "events", "ix_events_time_fired")
     elif new_version == 2:
         # Create compound start/end index for recorder_runs
-        _create_index(connection, "recorder_runs", "ix_recorder_runs_start_end")
+        _create_index(instance, "recorder_runs", "ix_recorder_runs_start_end")
         # Create indexes for states
-        _create_index(connection, "states", "ix_states_last_updated")
+        _create_index(instance, "states", "ix_states_last_updated")
     elif new_version == 3:
         # There used to be a new index here, but it was removed in version 4.
         pass
@@ -374,41 +391,41 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         if old_version == 3:
             # Remove index that was added in version 3
-            _drop_index(connection, "states", "ix_states_created_domain")
+            _drop_index(instance, "states", "ix_states_created_domain")
         if old_version == 2:
             # Remove index that was added in version 2
-            _drop_index(connection, "states", "ix_states_entity_id_created")
+            _drop_index(instance, "states", "ix_states_entity_id_created")

         # Remove indexes that were added in version 0
-        _drop_index(connection, "states", "states__state_changes")
-        _drop_index(connection, "states", "states__significant_changes")
-        _drop_index(connection, "states", "ix_states_entity_id_created")
+        _drop_index(instance, "states", "states__state_changes")
+        _drop_index(instance, "states", "states__significant_changes")
+        _drop_index(instance, "states", "ix_states_entity_id_created")

-        _create_index(connection, "states", "ix_states_entity_id_last_updated")
+        _create_index(instance, "states", "ix_states_entity_id_last_updated")
     elif new_version == 5:
         # Create supporting index for States.event_id foreign key
-        _create_index(connection, "states", "ix_states_event_id")
+        _create_index(instance, "states", "ix_states_event_id")
     elif new_version == 6:
         _add_columns(
-            session,
+            instance,
             "events",
             ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
         )
-        _create_index(connection, "events", "ix_events_context_id")
-        _create_index(connection, "events", "ix_events_context_user_id")
+        _create_index(instance, "events", "ix_events_context_id")
+        _create_index(instance, "events", "ix_events_context_user_id")
         _add_columns(
-            connection,
+            instance,
             "states",
             ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
         )
-        _create_index(connection, "states", "ix_states_context_id")
-        _create_index(connection, "states", "ix_states_context_user_id")
+        _create_index(instance, "states", "ix_states_context_id")
+        _create_index(instance, "states", "ix_states_context_user_id")
     elif new_version == 7:
-        _create_index(connection, "states", "ix_states_entity_id")
+        _create_index(instance, "states", "ix_states_entity_id")
     elif new_version == 8:
-        _add_columns(connection, "events", ["context_parent_id CHARACTER(36)"])
-        _add_columns(connection, "states", ["old_state_id INTEGER"])
-        _create_index(connection, "events", "ix_events_context_parent_id")
+        _add_columns(instance, "events", ["context_parent_id CHARACTER(36)"])
+        _add_columns(instance, "states", ["old_state_id INTEGER"])
+        _create_index(instance, "events", "ix_events_context_parent_id")
     elif new_version == 9:
         # We now get the context from events with a join
         # since its always there on state_changed events
@@ -418,36 +435,36 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         # and we would have to move to something like
         # sqlalchemy alembic to make that work
         #
-        _drop_index(connection, "states", "ix_states_context_id")
-        _drop_index(connection, "states", "ix_states_context_user_id")
+        _drop_index(instance, "states", "ix_states_context_id")
+        _drop_index(instance, "states", "ix_states_context_user_id")
         # This index won't be there if they were not running
         # nightly but we don't treat that as a critical issue
-        _drop_index(connection, "states", "ix_states_context_parent_id")
+        _drop_index(instance, "states", "ix_states_context_parent_id")
         # Redundant keys on composite index:
         # We already have ix_states_entity_id_last_updated
-        _drop_index(connection, "states", "ix_states_entity_id")
-        _create_index(connection, "events", "ix_events_event_type_time_fired")
-        _drop_index(connection, "events", "ix_events_event_type")
+        _drop_index(instance, "states", "ix_states_entity_id")
+        _create_index(instance, "events", "ix_events_event_type_time_fired")
+        _drop_index(instance, "events", "ix_events_event_type")
     elif new_version == 10:
         # Now done in step 11
         pass
     elif new_version == 11:
-        _create_index(connection, "states", "ix_states_old_state_id")
-        _update_states_table_with_foreign_key_options(connection, engine)
+        _create_index(instance, "states", "ix_states_old_state_id")
+        _update_states_table_with_foreign_key_options(instance, engine)
     elif new_version == 12:
         if engine.dialect.name == "mysql":
-            _modify_columns(connection, engine, "events", ["event_data LONGTEXT"])
-            _modify_columns(connection, engine, "states", ["attributes LONGTEXT"])
+            _modify_columns(instance, engine, "events", ["event_data LONGTEXT"])
+            _modify_columns(instance, engine, "states", ["attributes LONGTEXT"])
     elif new_version == 13:
         if engine.dialect.name == "mysql":
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "events",
                 ["time_fired DATETIME(6)", "created DATETIME(6)"],
             )
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "states",
                 [
@@ -457,14 +474,12 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
             ],
         )
     elif new_version == 14:
-        _modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
+        _modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
     elif new_version == 15:
         # This dropped the statistics table, done again in version 18.
         pass
     elif new_version == 16:
-        _drop_foreign_key_constraints(
-            connection, engine, TABLE_STATES, ["old_state_id"]
-        )
+        _drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"])
     elif new_version == 17:
         # This dropped the statistics table, done again in version 18.
         pass
@@ -489,12 +504,13 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
     elif new_version == 19:
         # This adds the statistic runs table, insert a fake run to prevent duplicating
         # statistics.
-        session.add(StatisticsRuns(start=get_start_time()))
+        with session_scope(session=instance.get_session()) as session:
+            session.add(StatisticsRuns(start=get_start_time()))
     elif new_version == 20:
         # This changed the precision of statistics from float to double
         if engine.dialect.name in ["mysql", "postgresql"]:
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "statistics",
                 [
@@ -516,6 +532,8 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
                     table,
                 )
                 with contextlib.suppress(SQLAlchemyError):
-                    connection.execute(
-                        # Using LOCK=EXCLUSIVE to prevent the database from corrupting
-                        # https://github.com/home-assistant/core/issues/56104
+                    with session_scope(session=instance.get_session()) as session:
+                        connection = session.connection()
+                        connection.execute(
+                            # Using LOCK=EXCLUSIVE to prevent the database from corrupting
+                            # https://github.com/home-assistant/core/issues/56104
@@ -549,8 +567,11 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         # Block 5-minute statistics for one hour from the last run, or it will overlap
         # with existing hourly statistics. Don't block on a database with no existing
         # statistics.
-        if session.query(Statistics.id).count() and (
-            last_run_string := session.query(func.max(StatisticsRuns.start)).scalar()
-        ):
-            last_run_start_time = process_timestamp(last_run_string)
-            if last_run_start_time:
+        with session_scope(session=instance.get_session()) as session:
+            if session.query(Statistics.id).count() and (
+                last_run_string := session.query(
+                    func.max(StatisticsRuns.start)
+                ).scalar()
+            ):
+                last_run_start_time = process_timestamp(last_run_string)
+                if last_run_start_time:
@@ -562,7 +583,10 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         # When querying the database, be careful to only explicitly query for columns
         # which were present in schema version 21. If querying the table, SQLAlchemy
         # will refer to future columns.
-        for sum_statistic in session.query(StatisticsMeta.id).filter_by(has_sum=true()):
-            last_statistic = (
-                session.query(
-                    Statistics.start,
+        with session_scope(session=instance.get_session()) as session:
+            for sum_statistic in session.query(StatisticsMeta.id).filter_by(
+                has_sum=true()
+            ):
+                last_statistic = (
+                    session.query(
+                        Statistics.start,
@@ -586,20 +610,21 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         )
     elif new_version == 23:
         # Add name column to StatisticsMeta
-        _add_columns(session, "statistics_meta", ["name VARCHAR(255)"])
+        _add_columns(instance, "statistics_meta", ["name VARCHAR(255)"])
     elif new_version == 24:
         # Delete duplicated statistics
-        delete_duplicates(instance, session)
+        with session_scope(session=instance.get_session()) as session:
+            delete_duplicates(instance, session)
         # Recreate statistics indices to block duplicated statistics
-        _drop_index(connection, "statistics", "ix_statistics_statistic_id_start")
-        _create_index(connection, "statistics", "ix_statistics_statistic_id_start")
+        _drop_index(instance, "statistics", "ix_statistics_statistic_id_start")
+        _create_index(instance, "statistics", "ix_statistics_statistic_id_start")
         _drop_index(
-            connection,
+            instance,
             "statistics_short_term",
             "ix_statistics_short_term_statistic_id_start",
         )
         _create_index(
-            connection,
+            instance,
             "statistics_short_term",
             "ix_statistics_short_term_statistic_id_start",
         )

tests/components/recorder/test_migration.py

@@ -5,7 +5,7 @@ import importlib
 import sqlite3
 import sys
 import threading
-from unittest.mock import ANY, Mock, PropertyMock, call, patch
+from unittest.mock import Mock, PropertyMock, call, patch

 import pytest
 from sqlalchemy import create_engine, text
@@ -57,7 +57,7 @@ async def test_schema_update_calls(hass):
     assert recorder.util.async_migration_in_progress(hass) is False
     update.assert_has_calls(
         [
-            call(hass.data[DATA_INSTANCE], ANY, version + 1, 0)
+            call(hass.data[DATA_INSTANCE], version + 1, 0)
             for version in range(0, models.SCHEMA_VERSION)
         ]
     )
@@ -309,7 +309,7 @@ async def test_schema_migrate(hass, start_version):
 def test_invalid_update():
     """Test that an invalid new version raises an exception."""
     with pytest.raises(ValueError):
-        migration._apply_update(Mock(), Mock(), -1, 0)
+        migration._apply_update(Mock(), -1, 0)


 @pytest.mark.parametrize(
@@ -324,9 +324,13 @@ def test_invalid_update():
 def test_modify_column(engine_type, substr):
     """Test that modify column generates the expected query."""
     connection = Mock()
+    session = Mock()
+    session.connection = Mock(return_value=connection)
+    instance = Mock()
+    instance.get_session = Mock(return_value=session)
     engine = Mock()
     engine.dialect.name = engine_type
-    migration._modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
+    migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
     if substr:
         assert substr in connection.execute.call_args[0][0].text
     else:
@@ -338,8 +342,10 @@ def test_forgiving_add_column():
     engine = create_engine("sqlite://", poolclass=StaticPool)
     with Session(engine) as session:
         session.execute(text("CREATE TABLE hello (id int)"))
-        migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
-        migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
+        instance = Mock()
+        instance.get_session = Mock(return_value=session)
+        migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
+        migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])


 def test_forgiving_add_index():
@@ -347,7 +353,9 @@ def test_forgiving_add_index():
     engine = create_engine("sqlite://", poolclass=StaticPool)
     models.Base.metadata.create_all(engine)
     with Session(engine) as session:
-        migration._create_index(session, "states", "ix_states_context_id")
+        instance = Mock()
+        instance.get_session = Mock(return_value=session)
+        migration._create_index(instance, "states", "ix_states_context_id")


 @pytest.mark.parametrize(