diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 419ef7aa117..38eed25bee8 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -25,6 +25,7 @@ from sqlalchemy.schema import AddConstraint, DropConstraint from sqlalchemy.sql.expression import true from homeassistant.core import HomeAssistant +from homeassistant.util.enum import try_parse_enum from homeassistant.util.ulid import ulid_to_bytes from .const import SupportedDialect @@ -84,6 +85,38 @@ _EMPTY_EVENT_TYPE = "missing_event_type" _LOGGER = logging.getLogger(__name__) +@dataclass +class _ColumnTypesForDialect: + big_int_type: str + timestamp_type: str + context_bin_type: str + + +_MYSQL_COLUMN_TYPES = _ColumnTypesForDialect( + big_int_type="INTEGER(20)", + timestamp_type="DOUBLE PRECISION", + context_bin_type=f"BLOB({CONTEXT_ID_BIN_MAX_LENGTH})", +) + +_POSTGRESQL_COLUMN_TYPES = _ColumnTypesForDialect( + big_int_type="INTEGER", + timestamp_type="DOUBLE PRECISION", + context_bin_type="BYTEA", +) + +_SQLITE_COLUMN_TYPES = _ColumnTypesForDialect( + big_int_type="INTEGER", + timestamp_type="FLOAT", + context_bin_type="BLOB", +) + +_COLUMN_TYPES_FOR_DIALECT: dict[SupportedDialect | None, _ColumnTypesForDialect] = { + SupportedDialect.MYSQL: _MYSQL_COLUMN_TYPES, + SupportedDialect.POSTGRESQL: _POSTGRESQL_COLUMN_TYPES, + SupportedDialect.SQLITE: _SQLITE_COLUMN_TYPES, +} + + def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) -> None: """Raise if the exception and cause do not contain the match substrs.""" lower_ex_strs = [str(ex).lower(), str(ex.__cause__).lower()] @@ -544,18 +577,9 @@ def _apply_update( # noqa: C901 old_version: int, ) -> None: """Perform operations to bring schema up to date.""" - dialect = engine.dialect.name - big_int = "INTEGER(20)" if dialect == SupportedDialect.MYSQL else "INTEGER" - if dialect == SupportedDialect.MYSQL: - timestamp_type = "DOUBLE PRECISION" - context_bin_type = f"BLOB({CONTEXT_ID_BIN_MAX_LENGTH})" - elif dialect == SupportedDialect.POSTGRESQL: - timestamp_type = "DOUBLE PRECISION" - context_bin_type = "BYTEA" - else: - timestamp_type = "FLOAT" - context_bin_type = "BLOB" - + assert engine.dialect.name is not None, "Dialect name must be set" + dialect = try_parse_enum(SupportedDialect, engine.dialect.name) + _column_types = _COLUMN_TYPES_FOR_DIALECT.get(dialect, _SQLITE_COLUMN_TYPES) if new_version == 1: # This used to create ix_events_time_fired, but it was removed in version 32 pass @@ -817,12 +841,14 @@ def _apply_update( # noqa: C901 # of removing any duplicate if they still exist. pass elif new_version == 25: - _add_columns(session_maker, "states", [f"attributes_id {big_int}"]) + _add_columns( + session_maker, "states", [f"attributes_id {_column_types.big_int_type}"] + ) _create_index(session_maker, "states", "ix_states_attributes_id") elif new_version == 26: _create_index(session_maker, "statistics_runs", "ix_statistics_runs_start") elif new_version == 27: - _add_columns(session_maker, "events", [f"data_id {big_int}"]) + _add_columns(session_maker, "events", [f"data_id {_column_types.big_int_type}"]) _create_index(session_maker, "events", "ix_events_data_id") elif new_version == 28: _add_columns(session_maker, "events", ["origin_idx INTEGER"]) @@ -881,11 +907,16 @@ def _apply_update( # noqa: C901 # ALTER TABLE events DROP COLUMN time_fired # ALTER TABLE states DROP COLUMN last_updated # ALTER TABLE states DROP COLUMN last_changed - _add_columns(session_maker, "events", [f"time_fired_ts {timestamp_type}"]) + _add_columns( + session_maker, "events", [f"time_fired_ts {_column_types.timestamp_type}"] + ) _add_columns( session_maker, "states", - [f"last_updated_ts {timestamp_type}", f"last_changed_ts {timestamp_type}"], + [ + f"last_updated_ts {_column_types.timestamp_type}", + f"last_changed_ts {_column_types.timestamp_type}", + ], ) _create_index(session_maker, "events", "ix_events_time_fired_ts") _create_index(session_maker, "events", "ix_events_event_type_time_fired_ts") @@ -917,18 +948,18 @@ def _apply_update( # noqa: C901 session_maker, "statistics", [ - f"created_ts {timestamp_type}", - f"start_ts {timestamp_type}", - f"last_reset_ts {timestamp_type}", + f"created_ts {_column_types.timestamp_type}", + f"start_ts {_column_types.timestamp_type}", + f"last_reset_ts {_column_types.timestamp_type}", ], ) _add_columns( session_maker, "statistics_short_term", [ - f"created_ts {timestamp_type}", - f"start_ts {timestamp_type}", - f"last_reset_ts {timestamp_type}", + f"created_ts {_column_types.timestamp_type}", + f"start_ts {_column_types.timestamp_type}", + f"last_reset_ts {_column_types.timestamp_type}", ], ) _create_index(session_maker, "statistics", "ix_statistics_start_ts") @@ -983,20 +1014,24 @@ def _apply_update( # noqa: C901 session_maker, table, [ - f"context_id_bin {context_bin_type}", - f"context_user_id_bin {context_bin_type}", - f"context_parent_id_bin {context_bin_type}", + f"context_id_bin {_column_types.context_bin_type}", + f"context_user_id_bin {_column_types.context_bin_type}", + f"context_parent_id_bin {_column_types.context_bin_type}", ], ) _create_index(session_maker, "events", "ix_events_context_id_bin") _create_index(session_maker, "states", "ix_states_context_id_bin") elif new_version == 37: - _add_columns(session_maker, "events", [f"event_type_id {big_int}"]) + _add_columns( + session_maker, "events", [f"event_type_id {_column_types.big_int_type}"] + ) _create_index(session_maker, "events", "ix_events_event_type_id") _drop_index(session_maker, "events", "ix_events_event_type_time_fired_ts") _create_index(session_maker, "events", "ix_events_event_type_id_time_fired_ts") elif new_version == 38: - _add_columns(session_maker, "states", [f"metadata_id {big_int}"]) + _add_columns( + session_maker, "states", [f"metadata_id {_column_types.big_int_type}"] + ) _create_index(session_maker, "states", "ix_states_metadata_id") _create_index(session_maker, "states", "ix_states_metadata_id_last_updated_ts") elif new_version == 39: