Simplify recorder.migration._drop_foreign_key_constraints (#123968)

This commit is contained in:
Erik Montnemery 2024-08-15 18:58:52 +02:00 committed by GitHub
parent 46357519e0
commit 64a68b17f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 36 deletions

View File

@ -15,7 +15,6 @@ from uuid import UUID
import sqlalchemy import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text, update from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text, update
from sqlalchemy.engine import CursorResult, Engine from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
from sqlalchemy.exc import ( from sqlalchemy.exc import (
DatabaseError, DatabaseError,
IntegrityError, IntegrityError,
@ -645,7 +644,7 @@ def _update_states_table_with_foreign_key_options(
def _drop_foreign_key_constraints( def _drop_foreign_key_constraints(
session_maker: Callable[[], Session], engine: Engine, table: str, column: str session_maker: Callable[[], Session], engine: Engine, table: str, column: str
) -> list[tuple[str, str, ReflectedForeignKeyConstraint]]: ) -> None:
"""Drop foreign key constraints for a table on specific columns. """Drop foreign key constraints for a table on specific columns.
This is not supported for SQLite because it does not support This is not supported for SQLite because it does not support
@ -658,11 +657,6 @@ def _drop_foreign_key_constraints(
) )
inspector = sqlalchemy.inspect(engine) inspector = sqlalchemy.inspect(engine)
dropped_constraints = [
(table, column, foreign_key)
for foreign_key in inspector.get_foreign_keys(table)
if foreign_key["name"] and foreign_key["constrained_columns"] == [column]
]
## Find matching named constraints and bind the ForeignKeyConstraints to the table ## Find matching named constraints and bind the ForeignKeyConstraints to the table
tmp_table = Table(table, MetaData()) tmp_table = Table(table, MetaData())
@ -685,8 +679,6 @@ def _drop_foreign_key_constraints(
) )
raise raise
return dropped_constraints
def _restore_foreign_key_constraints( def _restore_foreign_key_constraints(
session_maker: Callable[[], Session], session_maker: Callable[[], Session],

View File

@ -7,7 +7,9 @@ import sys
from unittest.mock import ANY, Mock, PropertyMock, call, patch from unittest.mock import ANY, Mock, PropertyMock, call, patch
import pytest import pytest
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
from sqlalchemy.exc import ( from sqlalchemy.exc import (
DatabaseError, DatabaseError,
InternalError, InternalError,
@ -973,31 +975,40 @@ def test_drop_restore_foreign_key_constraints(recorder_db_url: str) -> None:
], ],
} }
def find_constraints(
engine: Engine, table: str, column: str
) -> list[tuple[str, str, ReflectedForeignKeyConstraint]]:
inspector = inspect(engine)
return [
(table, column, foreign_key)
for foreign_key in inspector.get_foreign_keys(table)
if foreign_key["name"] and foreign_key["constrained_columns"] == [column]
]
engine = create_engine(recorder_db_url) engine = create_engine(recorder_db_url)
db_schema.Base.metadata.create_all(engine) db_schema.Base.metadata.create_all(engine)
matching_constraints_1 = [
dropped_constraint
for table, column, _, _ in constraints_to_recreate
for dropped_constraint in find_constraints(engine, table, column)
]
assert matching_constraints_1 == expected_dropped_constraints[db_engine]
with Session(engine) as session: with Session(engine) as session:
session_maker = Mock(return_value=session) session_maker = Mock(return_value=session)
dropped_constraints_1 = [ for table, column, _, _ in constraints_to_recreate:
dropped_constraint migration._drop_foreign_key_constraints(
for table, column, _, _ in constraints_to_recreate
for dropped_constraint in migration._drop_foreign_key_constraints(
session_maker, engine, table, column session_maker, engine, table, column
) )
]
assert dropped_constraints_1 == expected_dropped_constraints[db_engine]
# Check we don't find the constrained columns again (they are removed) # Check we don't find the constrained columns again (they are removed)
with Session(engine) as session: matching_constraints_2 = [
session_maker = Mock(return_value=session) dropped_constraint
dropped_constraints_2 = [ for table, column, _, _ in constraints_to_recreate
dropped_constraint for dropped_constraint in find_constraints(engine, table, column)
for table, column, _, _ in constraints_to_recreate ]
for dropped_constraint in migration._drop_foreign_key_constraints( assert matching_constraints_2 == []
session_maker, engine, table, column
)
]
assert dropped_constraints_2 == []
# Restore the constraints # Restore the constraints
with Session(engine) as session: with Session(engine) as session:
@ -1007,16 +1018,12 @@ def test_drop_restore_foreign_key_constraints(recorder_db_url: str) -> None:
) )
# Check we do find the constrained columns again (they are restored) # Check we do find the constrained columns again (they are restored)
with Session(engine) as session: matching_constraints_3 = [
session_maker = Mock(return_value=session) dropped_constraint
dropped_constraints_3 = [ for table, column, _, _ in constraints_to_recreate
dropped_constraint for dropped_constraint in find_constraints(engine, table, column)
for table, column, _, _ in constraints_to_recreate ]
for dropped_constraint in migration._drop_foreign_key_constraints( assert matching_constraints_3 == expected_dropped_constraints[db_engine]
session_maker, engine, table, column
)
]
assert dropped_constraints_3 == expected_dropped_constraints[db_engine]
engine.dispose() engine.dispose()