diff --git a/homeassistant/components/sql/__init__.py b/homeassistant/components/sql/__init__.py index c0ec2dfab7f..92b640580eb 100644 --- a/homeassistant/components/sql/__init__.py +++ b/homeassistant/components/sql/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations import voluptuous as vol -from homeassistant.components.recorder import CONF_DB_URL +from homeassistant.components.recorder import CONF_DB_URL, get_instance from homeassistant.components.sensor import ( CONF_STATE_CLASS, DEVICE_CLASSES_SCHEMA, @@ -53,6 +53,18 @@ CONFIG_SCHEMA = vol.Schema( ) +def remove_configured_db_url_if_not_needed( + hass: HomeAssistant, entry: ConfigEntry +) -> None: + """Remove db url from config if it matches recorder database.""" + hass.config_entries.async_update_entry( + entry, + options={ + key: value for key, value in entry.options.items() if key != CONF_DB_URL + }, + ) + + async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: """Update listener for options.""" await hass.config_entries.async_reload(entry.entry_id) @@ -73,6 +85,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up SQL from a config entry.""" + if entry.options.get(CONF_DB_URL) == get_instance(hass).db_url: + remove_configured_db_url_if_not_needed(hass, entry) + entry.async_on_unload(entry.add_update_listener(async_update_listener)) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index a6b1afe4049..d52f2d10d0d 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -11,13 +11,14 @@ from sqlalchemy.orm import Session, scoped_session, sessionmaker import voluptuous as vol from homeassistant import config_entries -from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL +from homeassistant.components.recorder import CONF_DB_URL from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE from homeassistant.core import callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import selector from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN +from .util import resolve_db_url _LOGGER = logging.getLogger(__name__) @@ -85,34 +86,37 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ) -> FlowResult: """Handle the user step.""" errors = {} - db_url_default = DEFAULT_URL.format( - hass_config_path=self.hass.config.path(DEFAULT_DB_FILE) - ) if user_input is not None: - db_url = user_input.get(CONF_DB_URL, db_url_default) + db_url = user_input.get(CONF_DB_URL) query = user_input[CONF_QUERY] column = user_input[CONF_COLUMN_NAME] uom = user_input.get(CONF_UNIT_OF_MEASUREMENT) value_template = user_input.get(CONF_VALUE_TEMPLATE) name = user_input[CONF_NAME] + db_url_for_validation = None try: validate_sql_select(query) + db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( - validate_query, db_url, query, column + validate_query, db_url_for_validation, query, column ) except SQLAlchemyError: errors["db_url"] = "db_url_invalid" except ValueError: errors["query"] = "query_invalid" + add_db_url = ( + {CONF_DB_URL: db_url} if db_url == db_url_for_validation else {} + ) + if not errors: return self.async_create_entry( title=name, data={}, options={ - CONF_DB_URL: db_url, + **add_db_url, CONF_QUERY: query, CONF_COLUMN_NAME: column, CONF_UNIT_OF_MEASUREMENT: uom, @@ -140,32 +144,32 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow): ) -> FlowResult: """Manage SQL options.""" errors = {} - db_url_default = DEFAULT_URL.format( - hass_config_path=self.hass.config.path(DEFAULT_DB_FILE) - ) if user_input is not None: - db_url = user_input.get(CONF_DB_URL, db_url_default) + db_url = user_input.get(CONF_DB_URL) query = user_input[CONF_QUERY] column = user_input[CONF_COLUMN_NAME] name = self.entry.options.get(CONF_NAME, self.entry.title) try: validate_sql_select(query) + db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( - validate_query, db_url, query, column + validate_query, db_url_for_validation, query, column ) except SQLAlchemyError: errors["db_url"] = "db_url_invalid" except ValueError: errors["query"] = "query_invalid" else: + new_user_input = user_input + if new_user_input.get(CONF_DB_URL) and db_url == db_url_for_validation: + new_user_input.pop(CONF_DB_URL) return self.async_create_entry( title="", data={ CONF_NAME: name, - CONF_DB_URL: db_url, - **user_input, + **new_user_input, }, ) @@ -176,7 +180,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow): vol.Optional( CONF_DB_URL, description={ - "suggested_value": self.entry.options[CONF_DB_URL] + "suggested_value": self.entry.options.get(CONF_DB_URL) }, ): selector.TextSelector(), vol.Required( diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 27cf798db38..26c899d4d3c 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -10,7 +10,7 @@ from sqlalchemy.engine import Result from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session, scoped_session, sessionmaker -from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL +from homeassistant.components.recorder import CONF_DB_URL from homeassistant.components.sensor import ( CONF_STATE_CLASS, SensorDeviceClass, @@ -34,6 +34,7 @@ from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import CONF_COLUMN_NAME, CONF_QUERY, DB_URL_RE, DOMAIN +from .util import resolve_db_url _LOGGER = logging.getLogger(__name__) @@ -59,7 +60,7 @@ async def async_setup_platform( value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE) column_name: str = conf[CONF_COLUMN_NAME] unique_id: str | None = conf.get(CONF_UNIQUE_ID) - db_url: str | None = conf.get(CONF_DB_URL) + db_url: str = resolve_db_url(hass, conf.get(CONF_DB_URL)) device_class: SensorDeviceClass | None = conf.get(CONF_DEVICE_CLASS) state_class: SensorStateClass | None = conf.get(CONF_STATE_CLASS) @@ -87,7 +88,7 @@ async def async_setup_entry( ) -> None: """Set up the SQL sensor from config entry.""" - db_url: str = entry.options[CONF_DB_URL] + db_url: str = resolve_db_url(hass, entry.options.get(CONF_DB_URL)) name: str = entry.options[CONF_NAME] query_str: str = entry.options[CONF_QUERY] unit: str | None = entry.options.get(CONF_UNIT_OF_MEASUREMENT) @@ -128,7 +129,7 @@ async def async_setup_sensor( unit: str | None, value_template: Template | None, unique_id: str | None, - db_url: str | None, + db_url: str, yaml: bool, device_class: SensorDeviceClass | None, state_class: SensorStateClass | None, @@ -136,16 +137,12 @@ async def async_setup_sensor( ) -> None: """Set up the SQL sensor.""" - if not db_url: - db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) - - sess: Session | None = None try: engine = sqlalchemy.create_engine(db_url, future=True) sessmaker = scoped_session(sessionmaker(bind=engine, future=True)) # Run a dummy query just to test the db_url - sess = sessmaker() + sess: Session = sessmaker() sess.execute(sqlalchemy.text("SELECT 1;")) except SQLAlchemyError as err: diff --git a/homeassistant/components/sql/util.py b/homeassistant/components/sql/util.py new file mode 100644 index 00000000000..81d8cd9900c --- /dev/null +++ b/homeassistant/components/sql/util.py @@ -0,0 +1,12 @@ +"""Utils for sql.""" +from __future__ import annotations + +from homeassistant.components.recorder import get_instance +from homeassistant.core import HomeAssistant + + +def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str: + """Return the db_url provided if not empty, otherwise return the recorder db_url.""" + if db_url and not db_url.isspace(): + return db_url + return get_instance(hass).db_url diff --git a/tests/components/sql/__init__.py b/tests/components/sql/__init__.py index f6cfba01e35..ea58d066325 100644 --- a/tests/components/sql/__init__.py +++ b/tests/components/sql/__init__.py @@ -23,7 +23,6 @@ from homeassistant.core import HomeAssistant from tests.common import MockConfigEntry ENTRY_CONFIG = { - CONF_DB_URL: "sqlite://", CONF_NAME: "Get Value", CONF_QUERY: "SELECT 5 as value", CONF_COLUMN_NAME: "value", @@ -31,7 +30,6 @@ ENTRY_CONFIG = { } ENTRY_CONFIG_INVALID_QUERY = { - CONF_DB_URL: "sqlite://", CONF_NAME: "Get Value", CONF_QUERY: "UPDATE 5 as value", CONF_COLUMN_NAME: "size", @@ -39,14 +37,12 @@ ENTRY_CONFIG_INVALID_QUERY = { } ENTRY_CONFIG_INVALID_QUERY_OPT = { - CONF_DB_URL: "sqlite://", CONF_QUERY: "UPDATE 5 as value", CONF_COLUMN_NAME: "size", CONF_UNIT_OF_MEASUREMENT: "MiB", } ENTRY_CONFIG_NO_RESULTS = { - CONF_DB_URL: "sqlite://", CONF_NAME: "Get Value", CONF_QUERY: "SELECT kalle as value from no_table;", CONF_COLUMN_NAME: "value", @@ -69,7 +65,6 @@ YAML_CONFIG = { YAML_CONFIG_INVALID = { "sql": { - CONF_DB_URL: "sqlite://", CONF_QUERY: "SELECT 5 as value", CONF_COLUMN_NAME: "value", CONF_UNIT_OF_MEASUREMENT: "MiB", diff --git a/tests/components/sql/test_config_flow.py b/tests/components/sql/test_config_flow.py index 789fc983890..3213296a479 100644 --- a/tests/components/sql/test_config_flow.py +++ b/tests/components/sql/test_config_flow.py @@ -6,7 +6,7 @@ from unittest.mock import patch from sqlalchemy.exc import SQLAlchemyError from homeassistant import config_entries -from homeassistant.components.recorder import DEFAULT_DB_FILE, DEFAULT_URL, Recorder +from homeassistant.components.recorder import Recorder from homeassistant.components.sql.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -43,7 +43,6 @@ async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None: assert result2["type"] == FlowResultType.CREATE_ENTRY assert result2["title"] == "Get Value" assert result2["options"] == { - "db_url": "sqlite://", "name": "Get Value", "query": "SELECT 5 as value", "column": "value", @@ -113,7 +112,6 @@ async def test_flow_fails_invalid_query( assert result5["type"] == FlowResultType.CREATE_ENTRY assert result5["title"] == "Get Value" assert result5["options"] == { - "db_url": "sqlite://", "name": "Get Value", "query": "SELECT 5 as value", "column": "value", @@ -163,7 +161,6 @@ async def test_options_flow(recorder_mock: Recorder, hass: HomeAssistant) -> Non assert result["type"] == FlowResultType.CREATE_ENTRY assert result["data"] == { "name": "Get Value", - "db_url": "sqlite://", "query": "SELECT 5 as size", "column": "size", "unit_of_measurement": "MiB", @@ -215,7 +212,6 @@ async def test_options_flow_name_previously_removed( assert result["type"] == FlowResultType.CREATE_ENTRY assert result["data"] == { "name": "Get Value Title", - "db_url": "sqlite://", "query": "SELECT 5 as size", "column": "size", "unit_of_measurement": "MiB", @@ -316,7 +312,6 @@ async def test_options_flow_fails_invalid_query( assert result4["type"] == FlowResultType.CREATE_ENTRY assert result4["data"] == { "name": "Get Value", - "db_url": "sqlite://", "query": "SELECT 5 as size", "column": "size", "unit_of_measurement": "MiB", @@ -369,12 +364,9 @@ async def test_options_flow_db_url_empty( ) await hass.async_block_till_done() - db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) - assert result["type"] == FlowResultType.CREATE_ENTRY assert result["data"] == { "name": "Get Value", - "db_url": db_url, "query": "SELECT 5 as size", "column": "size", "unit_of_measurement": "MiB", diff --git a/tests/components/sql/test_init.py b/tests/components/sql/test_init.py index a110f789a93..50de8aba7b3 100644 --- a/tests/components/sql/test_init.py +++ b/tests/components/sql/test_init.py @@ -8,6 +8,7 @@ import voluptuous as vol from homeassistant import config_entries from homeassistant.components.recorder import Recorder +from homeassistant.components.recorder.util import get_instance from homeassistant.components.sql import validate_sql_select from homeassistant.components.sql.const import DOMAIN from homeassistant.core import HomeAssistant @@ -56,3 +57,41 @@ async def test_invalid_query(hass: HomeAssistant) -> None: """Test invalid query.""" with pytest.raises(vol.Invalid): validate_sql_select("DROP TABLE *") + + +async def test_remove_configured_db_url_if_not_needed_when_not_needed( + recorder_mock: Recorder, + hass: HomeAssistant, +) -> None: + """Test configured db_url is replaced with None if matching the recorder db.""" + recorder_db_url = get_instance(hass).db_url + + config = { + "db_url": recorder_db_url, + "query": "SELECT 5 as value", + "column": "value", + "name": "count_tables", + } + + config_entry = await init_integration(hass, config) + + assert config_entry.options.get("db_url") is None + + +async def test_remove_configured_db_url_if_not_needed_when_needed( + recorder_mock: Recorder, + hass: HomeAssistant, +) -> None: + """Test configured db_url is not replaced if it differs from the recorder db.""" + db_url = "mssql://" + + config = { + "db_url": db_url, + "query": "SELECT 5 as value", + "column": "value", + "name": "count_tables", + } + + config_entry = await init_integration(hass, config) + + assert config_entry.options.get("db_url") == db_url diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index bc3143347b5..32e5a778a87 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -182,6 +182,7 @@ async def test_invalid_url_setup( async def test_invalid_url_on_update( + recorder_mock: Recorder, hass: HomeAssistant, caplog: pytest.LogCaptureFixture, ) -> None: @@ -192,22 +193,9 @@ async def test_invalid_url_on_update( "column": "value", "name": "count_tables", } - entry = MockConfigEntry( - domain=DOMAIN, - source=SOURCE_USER, - data={}, - options=config, - entry_id="1", - ) - - entry.add_to_hass(hass) - - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() + await init_integration(hass, config) with patch( - "homeassistant.components.recorder", - ), patch( "homeassistant.components.sql.sensor.sqlalchemy.engine.cursor.CursorResult", side_effect=SQLAlchemyError( "sqlite://homeassistant:hunter2@homeassistant.local" @@ -219,7 +207,6 @@ async def test_invalid_url_on_update( ) await hass.async_block_till_done() - assert "sqlite://homeassistant:hunter2@homeassistant.local" not in caplog.text assert "sqlite://****:****@homeassistant.local" in caplog.text diff --git a/tests/components/sql/test_util.py b/tests/components/sql/test_util.py new file mode 100644 index 00000000000..31adbe076eb --- /dev/null +++ b/tests/components/sql/test_util.py @@ -0,0 +1,25 @@ +"""Test the sql utils.""" +from unittest.mock import AsyncMock + +from homeassistant.components.recorder import get_instance +from homeassistant.components.sql.util import resolve_db_url +from homeassistant.core import HomeAssistant + + +async def test_resolve_db_url_when_none_configured( + recorder_mock: AsyncMock, + hass: HomeAssistant, +): + """Test return recorder db_url if provided db_url is None.""" + db_url = None + resolved_url = resolve_db_url(hass, db_url) + + assert resolved_url == get_instance(hass).db_url + + +async def test_resolve_db_url_when_configured(hass: HomeAssistant): + """Test return provided db_url if it's set.""" + db_url = "mssql://" + resolved_url = resolve_db_url(hass, db_url) + + assert resolved_url == db_url