Make Recorder dialect_name a cached_property (#117922)

This commit is contained in:
J. Nick Koston 2024-05-28 21:23:40 -10:00 committed by GitHub
parent f7d2d94fdc
commit 0888233f06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 87 additions and 53 deletions

View File

@ -7,6 +7,7 @@ from collections.abc import Callable, Iterable
from concurrent.futures import CancelledError from concurrent.futures import CancelledError
import contextlib import contextlib
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import cached_property
import logging import logging
import queue import queue
import sqlite3 import sqlite3
@ -258,7 +259,7 @@ class Recorder(threading.Thread):
"""Return the number of items in the recorder backlog.""" """Return the number of items in the recorder backlog."""
return self._queue.qsize() return self._queue.qsize()
@property @cached_property
def dialect_name(self) -> SupportedDialect | None: def dialect_name(self) -> SupportedDialect | None:
"""Return the dialect the recorder uses.""" """Return the dialect the recorder uses."""
return self._dialect_name return self._dialect_name
@ -1446,6 +1447,7 @@ class Recorder(threading.Thread):
self.engine = create_engine(self.db_url, **kwargs, future=True) self.engine = create_engine(self.db_url, **kwargs, future=True)
self._dialect_name = try_parse_enum(SupportedDialect, self.engine.dialect.name) self._dialect_name = try_parse_enum(SupportedDialect, self.engine.dialect.name)
self.__dict__.pop("dialect_name", None)
sqlalchemy_event.listen(self.engine, "connect", self._setup_recorder_connection) sqlalchemy_event.listen(self.engine, "connect", self._setup_recorder_connection)
Base.metadata.create_all(self.engine) Base.metadata.create_all(self.engine)

View File

@ -17,16 +17,14 @@ async def test_validate_db_schema_fix_float_issue(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine, db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with postgresql and mysql. """Test validating DB schema with postgresql and mysql.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_db_schema_precision", "homeassistant.components.recorder.auto_repairs.schema._validate_db_schema_precision",
return_value={"events.double precision"}, return_value={"events.double precision"},
@ -50,17 +48,19 @@ async def test_validate_db_schema_fix_float_issue(
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_validate_db_schema_fix_utf8_issue_event_data( async def test_validate_db_schema_fix_utf8_issue_event_data(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8",
return_value={"event_data.4-byte UTF-8"}, return_value={"event_data.4-byte UTF-8"},
@ -81,17 +81,19 @@ async def test_validate_db_schema_fix_utf8_issue_event_data(
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_validate_db_schema_fix_collation_issue( async def test_validate_db_schema_fix_collation_issue(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_has_correct_collation", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_has_correct_collation",
return_value={"events.utf8mb4_unicode_ci"}, return_value={"events.utf8mb4_unicode_ci"},

View File

@ -17,16 +17,14 @@ async def test_validate_db_schema_fix_float_issue(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine, db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with postgresql and mysql. """Test validating DB schema with postgresql and mysql.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_db_schema_precision", "homeassistant.components.recorder.auto_repairs.schema._validate_db_schema_precision",
return_value={"states.double precision"}, return_value={"states.double precision"},
@ -52,17 +50,19 @@ async def test_validate_db_schema_fix_float_issue(
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_validate_db_schema_fix_utf8_issue_states( async def test_validate_db_schema_fix_utf8_issue_states(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8",
return_value={"states.4-byte UTF-8"}, return_value={"states.4-byte UTF-8"},
@ -82,17 +82,19 @@ async def test_validate_db_schema_fix_utf8_issue_states(
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_validate_db_schema_fix_utf8_issue_state_attributes( async def test_validate_db_schema_fix_utf8_issue_state_attributes(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8",
return_value={"state_attributes.4-byte UTF-8"}, return_value={"state_attributes.4-byte UTF-8"},
@ -113,17 +115,19 @@ async def test_validate_db_schema_fix_utf8_issue_state_attributes(
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_validate_db_schema_fix_collation_issue( async def test_validate_db_schema_fix_collation_issue(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_has_correct_collation", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_has_correct_collation",
return_value={"states.utf8mb4_unicode_ci"}, return_value={"states.utf8mb4_unicode_ci"},

View File

@ -11,18 +11,20 @@ from ...common import async_wait_recording_done
from tests.typing import RecorderInstanceGenerator from tests.typing import RecorderInstanceGenerator
@pytest.mark.parametrize("db_engine", ["mysql"])
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
async def test_validate_db_schema_fix_utf8_issue( async def test_validate_db_schema_fix_utf8_issue(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_supports_utf8",
return_value={"statistics_meta.4-byte UTF-8"}, return_value={"statistics_meta.4-byte UTF-8"},
@ -51,15 +53,13 @@ async def test_validate_db_schema_fix_float_issue(
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
table: str, table: str,
db_engine: str, db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with postgresql and mysql. """Test validating DB schema with postgresql and mysql.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_db_schema_precision", "homeassistant.components.recorder.auto_repairs.schema._validate_db_schema_precision",
return_value={f"{table}.double precision"}, return_value={f"{table}.double precision"},
@ -90,17 +90,19 @@ async def test_validate_db_schema_fix_float_issue(
@pytest.mark.parametrize("enable_schema_validation", [True]) @pytest.mark.parametrize("enable_schema_validation", [True])
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_validate_db_schema_fix_collation_issue( async def test_validate_db_schema_fix_collation_issue(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
recorder_dialect_name: None,
db_engine: str,
) -> None: ) -> None:
"""Test validating DB schema with MySQL. """Test validating DB schema with MySQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with ( with (
patch("homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"),
patch( patch(
"homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_has_correct_collation", "homeassistant.components.recorder.auto_repairs.schema._validate_table_schema_has_correct_collation",
return_value={"statistics.utf8mb4_unicode_ci"}, return_value={"statistics.utf8mb4_unicode_ci"},

View File

@ -1,7 +1,5 @@
"""The test validating and repairing schema.""" """The test validating and repairing schema."""
from unittest.mock import patch
import pytest import pytest
from sqlalchemy import text from sqlalchemy import text
@ -28,15 +26,13 @@ async def test_validate_db_schema(
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant, hass: HomeAssistant,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
db_engine, db_engine: str,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test validating DB schema with MySQL and PostgreSQL. """Test validating DB schema with MySQL and PostgreSQL.
Note: The test uses SQLite, the purpose is only to exercise the code. Note: The test uses SQLite, the purpose is only to exercise the code.
""" """
with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
):
await async_setup_recorder_instance(hass) await async_setup_recorder_instance(hass)
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
assert "Schema validation failed" not in caplog.text assert "Schema validation failed" not in caplog.text

View File

@ -0,0 +1,26 @@
"""Fixtures for the recorder component tests."""
from collections.abc import Generator
from unittest.mock import patch
import pytest
from homeassistant.components import recorder
from homeassistant.core import HomeAssistant
@pytest.fixture
def recorder_dialect_name(
hass: HomeAssistant, db_engine: str
) -> Generator[None, None, None]:
"""Patch the recorder dialect."""
if instance := hass.data.get(recorder.DATA_INSTANCE):
instance.__dict__.pop("dialect_name", None)
with patch.object(instance, "_dialect_name", db_engine):
yield
instance.__dict__.pop("dialect_name", None)
else:
with patch(
"homeassistant.components.recorder.Recorder.dialect_name", db_engine
):
yield

View File

@ -8,7 +8,7 @@ from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
import sqlite3 import sqlite3
import threading import threading
from typing import cast from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@ -293,7 +293,7 @@ async def test_saving_state(hass: HomeAssistant, setup_recorder: None) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("dialect_name", "expected_attributes"), ("db_engine", "expected_attributes"),
[ [
(SupportedDialect.MYSQL, {"test_attr": 5, "test_attr_10": "silly\0stuff"}), (SupportedDialect.MYSQL, {"test_attr": 5, "test_attr_10": "silly\0stuff"}),
(SupportedDialect.POSTGRESQL, {"test_attr": 5, "test_attr_10": "silly"}), (SupportedDialect.POSTGRESQL, {"test_attr": 5, "test_attr_10": "silly"}),
@ -301,16 +301,17 @@ async def test_saving_state(hass: HomeAssistant, setup_recorder: None) -> None:
], ],
) )
async def test_saving_state_with_nul( async def test_saving_state_with_nul(
hass: HomeAssistant, setup_recorder: None, dialect_name, expected_attributes hass: HomeAssistant,
db_engine: str,
recorder_dialect_name: None,
setup_recorder: None,
expected_attributes: dict[str, Any],
) -> None: ) -> None:
"""Test saving and restoring a state with nul in attributes.""" """Test saving and restoring a state with nul in attributes."""
entity_id = "test.recorder" entity_id = "test.recorder"
state = "restoring_from_db" state = "restoring_from_db"
attributes = {"test_attr": 5, "test_attr_10": "silly\0stuff"} attributes = {"test_attr": 5, "test_attr_10": "silly\0stuff"}
with patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
):
hass.states.async_set(entity_id, state, attributes) hass.states.async_set(entity_id, state, attributes)
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
@ -2071,13 +2072,14 @@ async def test_in_memory_database(
assert "In-memory SQLite database is not supported" in caplog.text assert "In-memory SQLite database is not supported" in caplog.text
@pytest.mark.parametrize("db_engine", ["mysql"])
async def test_database_connection_keep_alive( async def test_database_connection_keep_alive(
hass: HomeAssistant, hass: HomeAssistant,
recorder_dialect_name: None,
async_setup_recorder_instance: RecorderInstanceGenerator, async_setup_recorder_instance: RecorderInstanceGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test we keep alive socket based dialects.""" """Test we keep alive socket based dialects."""
with patch("homeassistant.components.recorder.Recorder.dialect_name"):
instance = await async_setup_recorder_instance(hass) instance = await async_setup_recorder_instance(hass)
# We have to mock this since we don't have a mock # We have to mock this since we don't have a mock
# MySQL server available in tests. # MySQL server available in tests.

View File

@ -37,18 +37,18 @@ async def test_recorder_system_health(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dialect_name", [SupportedDialect.MYSQL, SupportedDialect.POSTGRESQL] "db_engine", [SupportedDialect.MYSQL, SupportedDialect.POSTGRESQL]
) )
async def test_recorder_system_health_alternate_dbms( async def test_recorder_system_health_alternate_dbms(
recorder_mock: Recorder, hass: HomeAssistant, dialect_name recorder_mock: Recorder,
hass: HomeAssistant,
db_engine: SupportedDialect,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test recorder system health.""" """Test recorder system health."""
assert await async_setup_component(hass, "system_health", {}) assert await async_setup_component(hass, "system_health", {})
await async_wait_recording_done(hass) await async_wait_recording_done(hass)
with ( with (
patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
),
patch( patch(
"sqlalchemy.orm.session.Session.execute", "sqlalchemy.orm.session.Session.execute",
return_value=Mock(scalar=Mock(return_value=("1048576"))), return_value=Mock(scalar=Mock(return_value=("1048576"))),
@ -60,16 +60,19 @@ async def test_recorder_system_health_alternate_dbms(
"current_recorder_run": instance.recorder_runs_manager.current.start, "current_recorder_run": instance.recorder_runs_manager.current.start,
"oldest_recorder_run": instance.recorder_runs_manager.first.start, "oldest_recorder_run": instance.recorder_runs_manager.first.start,
"estimated_db_size": "1.00 MiB", "estimated_db_size": "1.00 MiB",
"database_engine": dialect_name.value, "database_engine": db_engine.value,
"database_version": ANY, "database_version": ANY,
} }
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dialect_name", [SupportedDialect.MYSQL, SupportedDialect.POSTGRESQL] "db_engine", [SupportedDialect.MYSQL, SupportedDialect.POSTGRESQL]
) )
async def test_recorder_system_health_db_url_missing_host( async def test_recorder_system_health_db_url_missing_host(
recorder_mock: Recorder, hass: HomeAssistant, dialect_name recorder_mock: Recorder,
hass: HomeAssistant,
db_engine: SupportedDialect,
recorder_dialect_name: None,
) -> None: ) -> None:
"""Test recorder system health with a db_url without a hostname.""" """Test recorder system health with a db_url without a hostname."""
assert await async_setup_component(hass, "system_health", {}) assert await async_setup_component(hass, "system_health", {})
@ -77,9 +80,6 @@ async def test_recorder_system_health_db_url_missing_host(
instance = get_instance(hass) instance = get_instance(hass)
with ( with (
patch(
"homeassistant.components.recorder.core.Recorder.dialect_name", dialect_name
),
patch.object( patch.object(
instance, instance,
"db_url", "db_url",
@ -95,7 +95,7 @@ async def test_recorder_system_health_db_url_missing_host(
"current_recorder_run": instance.recorder_runs_manager.current.start, "current_recorder_run": instance.recorder_runs_manager.current.start,
"oldest_recorder_run": instance.recorder_runs_manager.first.start, "oldest_recorder_run": instance.recorder_runs_manager.first.start,
"estimated_db_size": "1.00 MiB", "estimated_db_size": "1.00 MiB",
"database_engine": dialect_name.value, "database_engine": db_engine.value,
"database_version": ANY, "database_version": ANY,
} }