Remove support for databases without ROW_NUMBER (#72092)

This commit is contained in:
Erik Montnemery 2022-05-19 04:52:38 +02:00 committed by GitHub
parent 3a13ffcf13
commit edd7a3427c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 130 deletions

View File

@ -180,7 +180,6 @@ class Recorder(threading.Thread):
self._completed_first_database_setup: bool | None = None self._completed_first_database_setup: bool | None = None
self.async_migration_event = asyncio.Event() self.async_migration_event = asyncio.Event()
self.migration_in_progress = False self.migration_in_progress = False
self._db_supports_row_number = True
self._database_lock_task: DatabaseLockTask | None = None self._database_lock_task: DatabaseLockTask | None = None
self._db_executor: DBInterruptibleThreadPoolExecutor | None = None self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
self._exclude_attributes_by_domain = exclude_attributes_by_domain self._exclude_attributes_by_domain = exclude_attributes_by_domain

View File

@ -437,22 +437,6 @@ def _compile_hourly_statistics_summary_mean_stmt(
return stmt return stmt
def _compile_hourly_statistics_summary_sum_legacy_stmt(
start_time: datetime, end_time: datetime
) -> StatementLambdaElement:
"""Generate the legacy sum statement for hourly statistics.
This is used for databases not supporting row number.
"""
stmt = lambda_stmt(lambda: select(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY))
stmt += (
lambda q: q.filter(StatisticsShortTerm.start >= start_time)
.filter(StatisticsShortTerm.start < end_time)
.order_by(StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc())
)
return stmt
def compile_hourly_statistics( def compile_hourly_statistics(
instance: Recorder, session: Session, start: datetime instance: Recorder, session: Session, start: datetime
) -> None: ) -> None:
@ -481,66 +465,37 @@ def compile_hourly_statistics(
} }
# Get last hour's last sum # Get last hour's last sum
if instance._db_supports_row_number: # pylint: disable=[protected-access] subquery = (
subquery = ( session.query(*QUERY_STATISTICS_SUMMARY_SUM)
session.query(*QUERY_STATISTICS_SUMMARY_SUM) .filter(StatisticsShortTerm.start >= bindparam("start_time"))
.filter(StatisticsShortTerm.start >= bindparam("start_time")) .filter(StatisticsShortTerm.start < bindparam("end_time"))
.filter(StatisticsShortTerm.start < bindparam("end_time")) .subquery()
.subquery() )
) query = (
query = ( session.query(subquery)
session.query(subquery) .filter(subquery.c.rownum == 1)
.filter(subquery.c.rownum == 1) .order_by(subquery.c.metadata_id)
.order_by(subquery.c.metadata_id) )
) stats = execute(query.params(start_time=start_time, end_time=end_time))
stats = execute(query.params(start_time=start_time, end_time=end_time))
if stats: if stats:
for stat in stats: for stat in stats:
metadata_id, start, last_reset, state, _sum, _ = stat metadata_id, start, last_reset, state, _sum, _ = stat
if metadata_id in summary: if metadata_id in summary:
summary[metadata_id].update( summary[metadata_id].update(
{ {
"last_reset": process_timestamp(last_reset),
"state": state,
"sum": _sum,
}
)
else:
summary[metadata_id] = {
"start": start_time,
"last_reset": process_timestamp(last_reset),
"state": state,
"sum": _sum,
}
else:
stmt = _compile_hourly_statistics_summary_sum_legacy_stmt(start_time, end_time)
stats = execute_stmt_lambda_element(session, stmt)
if stats:
for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore[no-any-return]
(
metadata_id,
last_reset,
state,
_sum,
) = next(group)
if metadata_id in summary:
summary[metadata_id].update(
{
"start": start_time,
"last_reset": process_timestamp(last_reset),
"state": state,
"sum": _sum,
}
)
else:
summary[metadata_id] = {
"start": start_time,
"last_reset": process_timestamp(last_reset), "last_reset": process_timestamp(last_reset),
"state": state, "state": state,
"sum": _sum, "sum": _sum,
} }
)
else:
summary[metadata_id] = {
"start": start_time,
"last_reset": process_timestamp(last_reset),
"state": state,
"sum": _sum,
}
# Insert compiled hourly statistics in the database # Insert compiled hourly statistics in the database
for metadata_id, stat in summary.items(): for metadata_id, stat in summary.items():

View File

@ -52,12 +52,9 @@ SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
DEFAULT_YIELD_STATES_ROWS = 32768 DEFAULT_YIELD_STATES_ROWS = 32768
MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_MARIA_DB = AwesomeVersion("10.3.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_MARIA_DB_ROWNUM = AwesomeVersion("10.2.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_MYSQL = AwesomeVersion("8.0.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_MYSQL = AwesomeVersion("8.0.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_MYSQL_ROWNUM = AwesomeVersion("5.8.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_PGSQL = AwesomeVersion("12.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_PGSQL = AwesomeVersion("12.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_SQLITE = AwesomeVersion("3.31.0", AwesomeVersionStrategy.SIMPLEVER) MIN_VERSION_SQLITE = AwesomeVersion("3.31.0", AwesomeVersionStrategy.SIMPLEVER)
MIN_VERSION_SQLITE_ROWNUM = AwesomeVersion("3.25.0", AwesomeVersionStrategy.SIMPLEVER)
# This is the maximum time after the recorder ends the session # This is the maximum time after the recorder ends the session
# before we no longer consider startup to be a "restart" and we # before we no longer consider startup to be a "restart" and we
@ -414,10 +411,6 @@ def setup_connection_for_dialect(
version_string = result[0][0] version_string = result[0][0]
version = _extract_version_from_server_response(version_string) version = _extract_version_from_server_response(version_string)
if version and version < MIN_VERSION_SQLITE_ROWNUM:
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
False
)
if not version or version < MIN_VERSION_SQLITE: if not version or version < MIN_VERSION_SQLITE:
_fail_unsupported_version( _fail_unsupported_version(
version or version_string, "SQLite", MIN_VERSION_SQLITE version or version_string, "SQLite", MIN_VERSION_SQLITE
@ -448,19 +441,11 @@ def setup_connection_for_dialect(
is_maria_db = "mariadb" in version_string.lower() is_maria_db = "mariadb" in version_string.lower()
if is_maria_db: if is_maria_db:
if version and version < MIN_VERSION_MARIA_DB_ROWNUM:
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
False
)
if not version or version < MIN_VERSION_MARIA_DB: if not version or version < MIN_VERSION_MARIA_DB:
_fail_unsupported_version( _fail_unsupported_version(
version or version_string, "MariaDB", MIN_VERSION_MARIA_DB version or version_string, "MariaDB", MIN_VERSION_MARIA_DB
) )
else: else:
if version and version < MIN_VERSION_MYSQL_ROWNUM:
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
False
)
if not version or version < MIN_VERSION_MYSQL: if not version or version < MIN_VERSION_MYSQL:
_fail_unsupported_version( _fail_unsupported_version(
version or version_string, "MySQL", MIN_VERSION_MYSQL version or version_string, "MySQL", MIN_VERSION_MYSQL

View File

@ -166,15 +166,12 @@ async def test_last_run_was_recently_clean(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mysql_version, db_supports_row_number", "mysql_version",
[ ["10.3.0-MariaDB", "8.0.0"],
("10.3.0-MariaDB", True),
("8.0.0", True),
],
) )
def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number): def test_setup_connection_for_dialect_mysql(mysql_version):
"""Test setting up the connection for a mysql dialect.""" """Test setting up the connection for a mysql dialect."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -199,18 +196,14 @@ def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_numbe
assert execute_args[0] == "SET session wait_timeout=28800" assert execute_args[0] == "SET session wait_timeout=28800"
assert execute_args[1] == "SELECT VERSION()" assert execute_args[1] == "SELECT VERSION()"
assert instance_mock._db_supports_row_number == db_supports_row_number
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sqlite_version, db_supports_row_number", "sqlite_version",
[ ["3.31.0"],
("3.31.0", True),
],
) )
def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number): def test_setup_connection_for_dialect_sqlite(sqlite_version):
"""Test setting up the connection for a sqlite dialect.""" """Test setting up the connection for a sqlite dialect."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -246,20 +239,16 @@ def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_num
assert execute_args[1] == "PRAGMA synchronous=NORMAL" assert execute_args[1] == "PRAGMA synchronous=NORMAL"
assert execute_args[2] == "PRAGMA foreign_keys=ON" assert execute_args[2] == "PRAGMA foreign_keys=ON"
assert instance_mock._db_supports_row_number == db_supports_row_number
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sqlite_version, db_supports_row_number", "sqlite_version",
[ ["3.31.0"],
("3.31.0", True),
],
) )
def test_setup_connection_for_dialect_sqlite_zero_commit_interval( def test_setup_connection_for_dialect_sqlite_zero_commit_interval(
sqlite_version, db_supports_row_number sqlite_version,
): ):
"""Test setting up the connection for a sqlite dialect with a zero commit interval.""" """Test setting up the connection for a sqlite dialect with a zero commit interval."""
instance_mock = MagicMock(_db_supports_row_number=True, commit_interval=0) instance_mock = MagicMock(commit_interval=0)
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -295,8 +284,6 @@ def test_setup_connection_for_dialect_sqlite_zero_commit_interval(
assert execute_args[1] == "PRAGMA synchronous=FULL" assert execute_args[1] == "PRAGMA synchronous=FULL"
assert execute_args[2] == "PRAGMA foreign_keys=ON" assert execute_args[2] == "PRAGMA foreign_keys=ON"
assert instance_mock._db_supports_row_number == db_supports_row_number
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mysql_version,message", "mysql_version,message",
@ -317,7 +304,7 @@ def test_setup_connection_for_dialect_sqlite_zero_commit_interval(
) )
def test_fail_outdated_mysql(caplog, mysql_version, message): def test_fail_outdated_mysql(caplog, mysql_version, message):
"""Test setting up the connection for an outdated mysql version.""" """Test setting up the connection for an outdated mysql version."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -353,7 +340,7 @@ def test_fail_outdated_mysql(caplog, mysql_version, message):
) )
def test_supported_mysql(caplog, mysql_version): def test_supported_mysql(caplog, mysql_version):
"""Test setting up the connection for a supported mysql version.""" """Test setting up the connection for a supported mysql version."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -396,7 +383,7 @@ def test_supported_mysql(caplog, mysql_version):
) )
def test_fail_outdated_pgsql(caplog, pgsql_version, message): def test_fail_outdated_pgsql(caplog, pgsql_version, message):
"""Test setting up the connection for an outdated PostgreSQL version.""" """Test setting up the connection for an outdated PostgreSQL version."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -429,7 +416,7 @@ def test_fail_outdated_pgsql(caplog, pgsql_version, message):
) )
def test_supported_pgsql(caplog, pgsql_version): def test_supported_pgsql(caplog, pgsql_version):
"""Test setting up the connection for a supported PostgreSQL version.""" """Test setting up the connection for a supported PostgreSQL version."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -474,7 +461,7 @@ def test_supported_pgsql(caplog, pgsql_version):
) )
def test_fail_outdated_sqlite(caplog, sqlite_version, message): def test_fail_outdated_sqlite(caplog, sqlite_version, message):
"""Test setting up the connection for an outdated sqlite version.""" """Test setting up the connection for an outdated sqlite version."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()
@ -510,7 +497,7 @@ def test_fail_outdated_sqlite(caplog, sqlite_version, message):
) )
def test_supported_sqlite(caplog, sqlite_version): def test_supported_sqlite(caplog, sqlite_version):
"""Test setting up the connection for a supported sqlite version.""" """Test setting up the connection for a supported sqlite version."""
instance_mock = MagicMock(_db_supports_row_number=True) instance_mock = MagicMock()
execute_args = [] execute_args = []
close_mock = MagicMock() close_mock = MagicMock()

View File

@ -2279,13 +2279,7 @@ def test_compile_hourly_statistics_changing_statistics(
assert "Error while processing event StatisticsTask" not in caplog.text assert "Error while processing event StatisticsTask" not in caplog.text
@pytest.mark.parametrize( def test_compile_statistics_hourly_daily_monthly_summary(hass_recorder, caplog):
"db_supports_row_number,in_log,not_in_log",
[(True, "row_number", None), (False, None, "row_number")],
)
def test_compile_statistics_hourly_daily_monthly_summary(
hass_recorder, caplog, db_supports_row_number, in_log, not_in_log
):
"""Test compiling hourly statistics + monthly and daily summary.""" """Test compiling hourly statistics + monthly and daily summary."""
zero = dt_util.utcnow() zero = dt_util.utcnow()
# August 31st, 23:00 local time # August 31st, 23:00 local time
@ -2299,7 +2293,6 @@ def test_compile_statistics_hourly_daily_monthly_summary(
# Remove this after dropping the use of the hass_recorder fixture # Remove this after dropping the use of the hass_recorder fixture
hass.config.set_time_zone("America/Regina") hass.config.set_time_zone("America/Regina")
recorder = hass.data[DATA_INSTANCE] recorder = hass.data[DATA_INSTANCE]
recorder._db_supports_row_number = db_supports_row_number
setup_component(hass, "sensor", {}) setup_component(hass, "sensor", {})
wait_recording_done(hass) # Wait for the sensor recorder platform to be added wait_recording_done(hass) # Wait for the sensor recorder platform to be added
attributes = { attributes = {
@ -2693,10 +2686,6 @@ def test_compile_statistics_hourly_daily_monthly_summary(
assert stats == expected_stats assert stats == expected_stats
assert "Error while processing event StatisticsTask" not in caplog.text assert "Error while processing event StatisticsTask" not in caplog.text
if in_log:
assert in_log in caplog.text
if not_in_log:
assert not_in_log not in caplog.text
def record_states(hass, zero, entity_id, attributes, seq=None): def record_states(hass, zero, entity_id, attributes, seq=None):