diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 750f504d096..7b43abd8dde 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -32,6 +32,7 @@ from .const import ( # noqa: F401 INTEGRATION_PLATFORM_EXCLUDE_ATTRIBUTES, INTEGRATION_PLATFORMS_LOAD_IN_RECORDER_THREAD, SQLITE_URL_PREFIX, + SupportedDialect, ) from .core import Recorder from .services import async_register_services diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index d52f2d10d0d..1c1ed6adae4 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -64,6 +64,7 @@ def validate_query(db_url: str, query: str, column: str) -> bool: if sess: sess.close() + engine.dispose() return True diff --git a/homeassistant/components/sql/models.py b/homeassistant/components/sql/models.py new file mode 100644 index 00000000000..feac9ebf20c --- /dev/null +++ b/homeassistant/components/sql/models.py @@ -0,0 +1,16 @@ +"""The sql integration models.""" +from __future__ import annotations + +from dataclasses import dataclass + +from sqlalchemy.orm import scoped_session + +from homeassistant.core import CALLBACK_TYPE + + +@dataclass(slots=True) +class SQLData: + """Data for the sql integration.""" + + shutdown_event_cancel: CALLBACK_TYPE + session_makers_by_db_url: dict[str, scoped_session] diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index b6cce467e1f..eb0e9c9c46b 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -13,7 +13,11 @@ from sqlalchemy.orm import Session, scoped_session, sessionmaker from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.util import LRUCache -from homeassistant.components.recorder import CONF_DB_URL, get_instance +from homeassistant.components.recorder import ( + CONF_DB_URL, + SupportedDialect, + get_instance, +) from homeassistant.components.sensor import ( CONF_STATE_CLASS, SensorDeviceClass, @@ -27,8 +31,9 @@ from homeassistant.const import ( CONF_UNIQUE_ID, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, + EVENT_HOMEASSISTANT_STOP, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import TemplateError from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.device_registry import DeviceEntryType @@ -38,6 +43,7 @@ from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import CONF_COLUMN_NAME, CONF_QUERY, DB_URL_RE, DOMAIN +from .models import SQLData from .util import resolve_db_url _LOGGER = logging.getLogger(__name__) @@ -127,6 +133,36 @@ async def async_setup_entry( ) +@callback +def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData: + """Get or initialize domain data.""" + if DOMAIN in hass.data: + sql_data: SQLData = hass.data[DOMAIN] + return sql_data + + session_makers_by_db_url: dict[str, scoped_session] = {} + + # + # Ensure we dispose of all engines at shutdown + # to avoid unclean disconnects + # + # Shutdown all sessions in the executor since they will + # do blocking I/O + # + def _shutdown_db_engines(event: Event) -> None: + """Shutdown all database engines.""" + for sessmaker in session_makers_by_db_url.values(): + sessmaker.connection().engine.dispose() + + cancel_shutdown = hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_STOP, _shutdown_db_engines + ) + + sql_data = SQLData(cancel_shutdown, session_makers_by_db_url) + hass.data[DOMAIN] = sql_data + return sql_data + + async def async_setup_sensor( hass: HomeAssistant, name: str, @@ -144,18 +180,30 @@ async def async_setup_sensor( """Set up the SQL sensor.""" instance = get_instance(hass) sessmaker: scoped_session | None - if use_database_executor := (db_url == instance.db_url): + sql_data = _async_get_or_init_domain_data(hass) + uses_recorder_db = db_url == instance.db_url + use_database_executor = False + if uses_recorder_db and instance.dialect_name == SupportedDialect.SQLITE: + use_database_executor = True assert instance.engine is not None sessmaker = scoped_session(sessionmaker(bind=instance.engine, future=True)) - elif not ( - sessmaker := await hass.async_add_executor_job( - _validate_and_get_session_maker_for_db_url, db_url - ) + # For other databases we need to create a new engine since + # we want the connection to use the default timezone and these + # database engines will use QueuePool as its only sqlite that + # needs our custom pool. If there is already a session maker + # for this db_url we can use that so we do not create a new engine + # for every sensor. + elif db_url in sql_data.session_makers_by_db_url: + sessmaker = sql_data.session_makers_by_db_url[db_url] + elif sessmaker := await hass.async_add_executor_job( + _validate_and_get_session_maker_for_db_url, db_url ): + sql_data.session_makers_by_db_url[db_url] = sessmaker + else: return upper_query = query_str.upper() - if use_database_executor: + if uses_recorder_db: redacted_query = redact_credentials(query_str) issue_key = unique_id if unique_id else redacted_query diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index 7e289565b37..cd123556daf 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -407,3 +407,53 @@ async def test_no_issue_when_view_has_the_text_entity_id_in_it( "Query contains entity_id but does not reference states_meta" not in caplog.text ) assert hass.states.get("sensor.get_entity_id") is not None + + +async def test_multiple_sensors_using_same_db( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: + """Test multiple sensors using the same db.""" + config = { + "db_url": "sqlite:///", + "query": "SELECT 5 as value", + "column": "value", + "name": "Select value SQL query", + } + config2 = { + "db_url": "sqlite:///", + "query": "SELECT 5 as value", + "column": "value", + "name": "Select value SQL query 2", + } + await init_integration(hass, config) + await init_integration(hass, config2, entry_id="2") + + state = hass.states.get("sensor.select_value_sql_query") + assert state.state == "5" + assert state.attributes["value"] == 5 + + state = hass.states.get("sensor.select_value_sql_query_2") + assert state.state == "5" + assert state.attributes["value"] == 5 + + +async def test_engine_is_disposed_at_stop( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: + """Test we dispose of the engine at stop.""" + config = { + "db_url": "sqlite:///", + "query": "SELECT 5 as value", + "column": "value", + "name": "Select value SQL query", + } + await init_integration(hass, config) + + state = hass.states.get("sensor.select_value_sql_query") + assert state.state == "5" + assert state.attributes["value"] == 5 + + with patch("sqlalchemy.engine.base.Engine.dispose") as mock_engine_dispose: + await hass.async_stop() + + assert mock_engine_dispose.call_count == 2