mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Detect if mysql and sqlite support row_number (#57475)
This commit is contained in:
parent
3ff30f53a7
commit
0139bfa749
@ -413,6 +413,7 @@ class Recorder(threading.Thread):
|
|||||||
self.async_migration_event = asyncio.Event()
|
self.async_migration_event = asyncio.Event()
|
||||||
self.migration_in_progress = False
|
self.migration_in_progress = False
|
||||||
self._queue_watcher = None
|
self._queue_watcher = None
|
||||||
|
self._db_supports_row_number = True
|
||||||
|
|
||||||
self.enabled = True
|
self.enabled = True
|
||||||
|
|
||||||
@ -972,6 +973,7 @@ class Recorder(threading.Thread):
|
|||||||
def setup_recorder_connection(dbapi_connection, connection_record):
|
def setup_recorder_connection(dbapi_connection, connection_record):
|
||||||
"""Dbapi specific connection settings."""
|
"""Dbapi specific connection settings."""
|
||||||
setup_connection_for_dialect(
|
setup_connection_for_dialect(
|
||||||
|
self,
|
||||||
self.engine.dialect.name,
|
self.engine.dialect.name,
|
||||||
dbapi_connection,
|
dbapi_connection,
|
||||||
not self._completed_first_database_setup,
|
not self._completed_first_database_setup,
|
||||||
|
@ -89,6 +89,13 @@ QUERY_STATISTICS_SUMMARY_SUM = [
|
|||||||
.label("rownum"),
|
.label("rownum"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
QUERY_STATISTICS_SUMMARY_SUM_LEGACY = [
|
||||||
|
StatisticsShortTerm.metadata_id,
|
||||||
|
StatisticsShortTerm.last_reset,
|
||||||
|
StatisticsShortTerm.state,
|
||||||
|
StatisticsShortTerm.sum,
|
||||||
|
]
|
||||||
|
|
||||||
QUERY_STATISTIC_META = [
|
QUERY_STATISTIC_META = [
|
||||||
StatisticsMeta.id,
|
StatisticsMeta.id,
|
||||||
StatisticsMeta.statistic_id,
|
StatisticsMeta.statistic_id,
|
||||||
@ -275,6 +282,7 @@ def compile_hourly_statistics(
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Get last hour's last sum
|
# Get last hour's last sum
|
||||||
|
if instance._db_supports_row_number: # pylint: disable=[protected-access]
|
||||||
subquery = (
|
subquery = (
|
||||||
session.query(*QUERY_STATISTICS_SUMMARY_SUM)
|
session.query(*QUERY_STATISTICS_SUMMARY_SUM)
|
||||||
.filter(StatisticsShortTerm.start >= bindparam("start_time"))
|
.filter(StatisticsShortTerm.start >= bindparam("start_time"))
|
||||||
@ -306,6 +314,49 @@ def compile_hourly_statistics(
|
|||||||
"state": state,
|
"state": state,
|
||||||
"sum": _sum,
|
"sum": _sum,
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
baked_query = instance.hass.data[STATISTICS_SHORT_TERM_BAKERY](
|
||||||
|
lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
|
||||||
|
)
|
||||||
|
|
||||||
|
baked_query += lambda q: q.filter(
|
||||||
|
StatisticsShortTerm.start >= bindparam("start_time")
|
||||||
|
)
|
||||||
|
baked_query += lambda q: q.filter(
|
||||||
|
StatisticsShortTerm.start < bindparam("end_time")
|
||||||
|
)
|
||||||
|
baked_query += lambda q: q.order_by(
|
||||||
|
StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
|
||||||
|
)
|
||||||
|
|
||||||
|
stats = execute(
|
||||||
|
baked_query(session).params(start_time=start_time, end_time=end_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
if stats:
|
||||||
|
for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]): # type: ignore
|
||||||
|
(
|
||||||
|
metadata_id,
|
||||||
|
last_reset,
|
||||||
|
state,
|
||||||
|
_sum,
|
||||||
|
) = next(group)
|
||||||
|
if metadata_id in summary:
|
||||||
|
summary[metadata_id].update(
|
||||||
|
{
|
||||||
|
"start": start_time,
|
||||||
|
"last_reset": process_timestamp(last_reset),
|
||||||
|
"state": state,
|
||||||
|
"sum": _sum,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
summary[metadata_id] = {
|
||||||
|
"start": start_time,
|
||||||
|
"last_reset": process_timestamp(last_reset),
|
||||||
|
"state": state,
|
||||||
|
"sum": _sum,
|
||||||
|
}
|
||||||
|
|
||||||
# Insert compiled hourly statistics in the database
|
# Insert compiled hourly statistics in the database
|
||||||
for metadata_id, stat in summary.items():
|
for metadata_id, stat in summary.items():
|
||||||
|
@ -266,7 +266,18 @@ def execute_on_connection(dbapi_connection, statement):
|
|||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connection):
|
def query_on_connection(dbapi_connection, statement):
|
||||||
|
"""Execute a single statement with a dbapi connection and return the result."""
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute(statement)
|
||||||
|
result = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def setup_connection_for_dialect(
|
||||||
|
instance, dialect_name, dbapi_connection, first_connection
|
||||||
|
):
|
||||||
"""Execute statements needed for dialect connection."""
|
"""Execute statements needed for dialect connection."""
|
||||||
# Returns False if the the connection needs to be setup
|
# Returns False if the the connection needs to be setup
|
||||||
# on the next connection, returns True if the connection
|
# on the next connection, returns True if the connection
|
||||||
@ -280,6 +291,13 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio
|
|||||||
# WAL mode only needs to be setup once
|
# WAL mode only needs to be setup once
|
||||||
# instead of every time we open the sqlite connection
|
# instead of every time we open the sqlite connection
|
||||||
# as its persistent and isn't free to call every time.
|
# as its persistent and isn't free to call every time.
|
||||||
|
result = query_on_connection(dbapi_connection, "SELECT sqlite_version()")
|
||||||
|
version = result[0][0]
|
||||||
|
major, minor, _patch = version.split(".", 2)
|
||||||
|
if int(major) == 3 and int(minor) < 25:
|
||||||
|
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
# approximately 8MiB of memory
|
# approximately 8MiB of memory
|
||||||
execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192")
|
execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192")
|
||||||
@ -289,6 +307,14 @@ def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connectio
|
|||||||
|
|
||||||
if dialect_name == "mysql":
|
if dialect_name == "mysql":
|
||||||
execute_on_connection(dbapi_connection, "SET session wait_timeout=28800")
|
execute_on_connection(dbapi_connection, "SET session wait_timeout=28800")
|
||||||
|
if first_connection:
|
||||||
|
result = query_on_connection(dbapi_connection, "SELECT VERSION()")
|
||||||
|
version = result[0][0]
|
||||||
|
major, minor, _patch = version.split(".", 2)
|
||||||
|
if int(major) == 5 and int(minor) < 8:
|
||||||
|
instance._db_supports_row_number = ( # pylint: disable=[protected-access]
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def end_incomplete_runs(session, start_time):
|
def end_incomplete_runs(session, start_time):
|
||||||
|
@ -122,44 +122,88 @@ async def test_last_run_was_recently_clean(hass):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_setup_connection_for_dialect_mysql():
|
@pytest.mark.parametrize(
|
||||||
|
"mysql_version, db_supports_row_number",
|
||||||
|
[
|
||||||
|
("10.0.0", True),
|
||||||
|
("5.8.0", True),
|
||||||
|
("5.7.0", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_setup_connection_for_dialect_mysql(mysql_version, db_supports_row_number):
|
||||||
"""Test setting up the connection for a mysql dialect."""
|
"""Test setting up the connection for a mysql dialect."""
|
||||||
execute_mock = MagicMock()
|
instance_mock = MagicMock(_db_supports_row_number=True)
|
||||||
|
execute_args = []
|
||||||
close_mock = MagicMock()
|
close_mock = MagicMock()
|
||||||
|
|
||||||
|
def execute_mock(statement):
|
||||||
|
nonlocal execute_args
|
||||||
|
execute_args.append(statement)
|
||||||
|
|
||||||
|
def fetchall_mock():
|
||||||
|
nonlocal execute_args
|
||||||
|
if execute_args[-1] == "SELECT VERSION()":
|
||||||
|
return [[mysql_version]]
|
||||||
|
return None
|
||||||
|
|
||||||
def _make_cursor_mock(*_):
|
def _make_cursor_mock(*_):
|
||||||
return MagicMock(execute=execute_mock, close=close_mock)
|
return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)
|
||||||
|
|
||||||
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
|
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
|
||||||
|
|
||||||
util.setup_connection_for_dialect("mysql", dbapi_connection, True)
|
util.setup_connection_for_dialect(instance_mock, "mysql", dbapi_connection, True)
|
||||||
|
|
||||||
assert execute_mock.call_args[0][0] == "SET session wait_timeout=28800"
|
assert len(execute_args) == 2
|
||||||
|
assert execute_args[0] == "SET session wait_timeout=28800"
|
||||||
|
assert execute_args[1] == "SELECT VERSION()"
|
||||||
|
|
||||||
|
assert instance_mock._db_supports_row_number == db_supports_row_number
|
||||||
|
|
||||||
|
|
||||||
def test_setup_connection_for_dialect_sqlite():
|
@pytest.mark.parametrize(
|
||||||
|
"sqlite_version, db_supports_row_number",
|
||||||
|
[
|
||||||
|
("3.25.0", True),
|
||||||
|
("3.24.0", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_setup_connection_for_dialect_sqlite(sqlite_version, db_supports_row_number):
|
||||||
"""Test setting up the connection for a sqlite dialect."""
|
"""Test setting up the connection for a sqlite dialect."""
|
||||||
execute_mock = MagicMock()
|
instance_mock = MagicMock(_db_supports_row_number=True)
|
||||||
|
execute_args = []
|
||||||
close_mock = MagicMock()
|
close_mock = MagicMock()
|
||||||
|
|
||||||
|
def execute_mock(statement):
|
||||||
|
nonlocal execute_args
|
||||||
|
execute_args.append(statement)
|
||||||
|
|
||||||
|
def fetchall_mock():
|
||||||
|
nonlocal execute_args
|
||||||
|
if execute_args[-1] == "SELECT sqlite_version()":
|
||||||
|
return [[sqlite_version]]
|
||||||
|
return None
|
||||||
|
|
||||||
def _make_cursor_mock(*_):
|
def _make_cursor_mock(*_):
|
||||||
return MagicMock(execute=execute_mock, close=close_mock)
|
return MagicMock(execute=execute_mock, close=close_mock, fetchall=fetchall_mock)
|
||||||
|
|
||||||
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
|
dbapi_connection = MagicMock(cursor=_make_cursor_mock)
|
||||||
|
|
||||||
util.setup_connection_for_dialect("sqlite", dbapi_connection, True)
|
util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, True)
|
||||||
|
|
||||||
assert len(execute_mock.call_args_list) == 3
|
assert len(execute_args) == 4
|
||||||
assert execute_mock.call_args_list[0][0][0] == "PRAGMA journal_mode=WAL"
|
assert execute_args[0] == "PRAGMA journal_mode=WAL"
|
||||||
assert execute_mock.call_args_list[1][0][0] == "PRAGMA cache_size = -8192"
|
assert execute_args[1] == "SELECT sqlite_version()"
|
||||||
assert execute_mock.call_args_list[2][0][0] == "PRAGMA foreign_keys=ON"
|
assert execute_args[2] == "PRAGMA cache_size = -8192"
|
||||||
|
assert execute_args[3] == "PRAGMA foreign_keys=ON"
|
||||||
|
|
||||||
execute_mock.reset_mock()
|
execute_args = []
|
||||||
util.setup_connection_for_dialect("sqlite", dbapi_connection, False)
|
util.setup_connection_for_dialect(instance_mock, "sqlite", dbapi_connection, False)
|
||||||
|
|
||||||
assert len(execute_mock.call_args_list) == 2
|
assert len(execute_args) == 2
|
||||||
assert execute_mock.call_args_list[0][0][0] == "PRAGMA cache_size = -8192"
|
assert execute_args[0] == "PRAGMA cache_size = -8192"
|
||||||
assert execute_mock.call_args_list[1][0][0] == "PRAGMA foreign_keys=ON"
|
assert execute_args[1] == "PRAGMA foreign_keys=ON"
|
||||||
|
|
||||||
|
assert instance_mock._db_supports_row_number == db_supports_row_number
|
||||||
|
|
||||||
|
|
||||||
def test_basic_sanity_check(hass_recorder):
|
def test_basic_sanity_check(hass_recorder):
|
||||||
|
@ -1806,7 +1806,13 @@ def test_compile_hourly_statistics_changing_statistics(
|
|||||||
assert "Error while processing event StatisticsTask" not in caplog.text
|
assert "Error while processing event StatisticsTask" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
def test_compile_statistics_hourly_summary(hass_recorder, caplog):
|
@pytest.mark.parametrize(
|
||||||
|
"db_supports_row_number,in_log,not_in_log",
|
||||||
|
[(True, "row_number", None), (False, None, "row_number")],
|
||||||
|
)
|
||||||
|
def test_compile_statistics_hourly_summary(
|
||||||
|
hass_recorder, caplog, db_supports_row_number, in_log, not_in_log
|
||||||
|
):
|
||||||
"""Test compiling hourly statistics."""
|
"""Test compiling hourly statistics."""
|
||||||
zero = dt_util.utcnow()
|
zero = dt_util.utcnow()
|
||||||
zero = zero.replace(minute=0, second=0, microsecond=0)
|
zero = zero.replace(minute=0, second=0, microsecond=0)
|
||||||
@ -1815,6 +1821,7 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
|
|||||||
zero += timedelta(hours=1)
|
zero += timedelta(hours=1)
|
||||||
hass = hass_recorder()
|
hass = hass_recorder()
|
||||||
recorder = hass.data[DATA_INSTANCE]
|
recorder = hass.data[DATA_INSTANCE]
|
||||||
|
recorder._db_supports_row_number = db_supports_row_number
|
||||||
setup_component(hass, "sensor", {})
|
setup_component(hass, "sensor", {})
|
||||||
attributes = {
|
attributes = {
|
||||||
"device_class": None,
|
"device_class": None,
|
||||||
@ -2052,6 +2059,10 @@ def test_compile_statistics_hourly_summary(hass_recorder, caplog):
|
|||||||
end += timedelta(hours=1)
|
end += timedelta(hours=1)
|
||||||
assert stats == expected_stats
|
assert stats == expected_stats
|
||||||
assert "Error while processing event StatisticsTask" not in caplog.text
|
assert "Error while processing event StatisticsTask" not in caplog.text
|
||||||
|
if in_log:
|
||||||
|
assert in_log in caplog.text
|
||||||
|
if not_in_log:
|
||||||
|
assert not_in_log not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
def record_states(hass, zero, entity_id, attributes, seq=None):
|
def record_states(hass, zero, entity_id, attributes, seq=None):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user