diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index d7358e96100..091bff8445f 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -331,7 +331,7 @@ class Recorder(threading.Thread): self._pending_expunge = [] self.event_session = None self.get_session = None - self._completed_database_setup = None + self._completed_first_database_setup = None self._event_listener = None self.async_migration_event = asyncio.Event() self.migration_in_progress = False @@ -837,15 +837,16 @@ class Recorder(threading.Thread): def _setup_connection(self): """Ensure database is ready to fly.""" kwargs = {} - self._completed_database_setup = False + self._completed_first_database_setup = False def setup_recorder_connection(dbapi_connection, connection_record): """Dbapi specific connection settings.""" - if self._completed_database_setup: - return - self._completed_database_setup = setup_connection_for_dialect( - self.engine.dialect.name, dbapi_connection + setup_connection_for_dialect( + self.engine.dialect.name, + dbapi_connection, + not self._completed_first_database_setup, ) + self._completed_first_database_setup = True if self.db_url == SQLITE_URL_PREFIX or ":memory:" in self.db_url: kwargs["connect_args"] = {"check_same_thread": False} diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 6231f493cc2..186aad4fe9e 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -254,26 +254,27 @@ def execute_on_connection(dbapi_connection, statement): cursor.close() -def setup_connection_for_dialect(dialect_name, dbapi_connection): +def setup_connection_for_dialect(dialect_name, dbapi_connection, first_connection): """Execute statements needed for dialect connection.""" # Returns False if the the connection needs to be setup # on the next connection, returns True if the connection # never needs to be setup again. if dialect_name == "sqlite": - old_isolation = dbapi_connection.isolation_level - dbapi_connection.isolation_level = None - execute_on_connection(dbapi_connection, "PRAGMA journal_mode=WAL") - dbapi_connection.isolation_level = old_isolation - # WAL mode only needs to be setup once - # instead of every time we open the sqlite connection - # as its persistent and isn't free to call every time. - return True + if first_connection: + old_isolation = dbapi_connection.isolation_level + dbapi_connection.isolation_level = None + execute_on_connection(dbapi_connection, "PRAGMA journal_mode=WAL") + dbapi_connection.isolation_level = old_isolation + # WAL mode only needs to be setup once + # instead of every time we open the sqlite connection + # as its persistent and isn't free to call every time. + + # approximately 8MiB of memory + execute_on_connection(dbapi_connection, "PRAGMA cache_size = -8192") if dialect_name == "mysql": execute_on_connection(dbapi_connection, "SET session wait_timeout=28800") - return False - def end_incomplete_runs(session, start_time): """End any incomplete recorder runs.""" diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index b5c5b68fe3f..0a9f90be83e 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -152,7 +152,7 @@ def test_setup_connection_for_dialect_mysql(): dbapi_connection = MagicMock(cursor=_make_cursor_mock) - assert util.setup_connection_for_dialect("mysql", dbapi_connection) is False + util.setup_connection_for_dialect("mysql", dbapi_connection, True) assert execute_mock.call_args[0][0] == "SET session wait_timeout=28800" @@ -167,9 +167,17 @@ def test_setup_connection_for_dialect_sqlite(): dbapi_connection = MagicMock(cursor=_make_cursor_mock) - assert util.setup_connection_for_dialect("sqlite", dbapi_connection) is True + util.setup_connection_for_dialect("sqlite", dbapi_connection, True) - assert execute_mock.call_args[0][0] == "PRAGMA journal_mode=WAL" + assert len(execute_mock.call_args_list) == 2 + assert execute_mock.call_args_list[0][0][0] == "PRAGMA journal_mode=WAL" + assert execute_mock.call_args_list[1][0][0] == "PRAGMA cache_size = -8192" + + execute_mock.reset_mock() + util.setup_connection_for_dialect("sqlite", dbapi_connection, False) + + assert len(execute_mock.call_args_list) == 1 + assert execute_mock.call_args_list[0][0][0] == "PRAGMA cache_size = -8192" def test_basic_sanity_check(hass_recorder):