diff --git a/homeassistant/components/sql/__init__.py b/homeassistant/components/sql/__init__.py index 4658e19932c..a4768165c25 100644 --- a/homeassistant/components/sql/__init__.py +++ b/homeassistant/components/sql/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import sqlparse import voluptuous as vol from homeassistant.components.recorder import CONF_DB_URL, get_instance @@ -38,9 +39,14 @@ _LOGGER = logging.getLogger(__name__) def validate_sql_select(value: str) -> str: """Validate that value is a SQL SELECT query.""" - if not value.lstrip().lower().startswith("select"): + if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: + raise vol.Invalid("Multiple SQL queries are not supported") + if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": + raise vol.Invalid("Invalid SQL query") + if query_type != "SELECT": + _LOGGER.debug("The SQL query %s is of type %s", query, query_type) raise vol.Invalid("Only SELECT queries allowed") - return value + return str(query[0]) QUERY_SCHEMA = vol.Schema( diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index e00b1f8e402..a697bdc51a7 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -6,8 +6,10 @@ from typing import Any import sqlalchemy from sqlalchemy.engine import Result -from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError +from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError from sqlalchemy.orm import Session, scoped_session, sessionmaker +import sqlparse +from sqlparse.exceptions import SQLParseError import voluptuous as vol from homeassistant import config_entries @@ -80,11 +82,16 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema( ).extend(OPTIONS_SCHEMA.schema) -def validate_sql_select(value: str) -> str | None: +def validate_sql_select(value: str) -> str: """Validate that value is a SQL SELECT query.""" - if not value.lstrip().lower().startswith("select"): - raise ValueError("Incorrect Query") - return value + if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: + raise MultipleResultsFound + if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": + raise ValueError + if query_type != "SELECT": + _LOGGER.debug("The SQL query %s is of type %s", query, query_type) + raise SQLParseError + return str(query[0]) def validate_query(db_url: str, query: str, column: str) -> bool: @@ -148,7 +155,7 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): db_url_for_validation = None try: - validate_sql_select(query) + query = validate_sql_select(query) db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( validate_query, db_url_for_validation, query, column @@ -156,9 +163,14 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): except NoSuchColumnError: errors["column"] = "column_invalid" description_placeholders = {"column": column} + except MultipleResultsFound: + errors["query"] = "multiple_queries" except SQLAlchemyError: errors["db_url"] = "db_url_invalid" - except ValueError: + except SQLParseError: + errors["query"] = "query_no_read_only" + except ValueError as err: + _LOGGER.debug("Invalid query: %s", err) errors["query"] = "query_invalid" options = { @@ -209,7 +221,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): name = self.options.get(CONF_NAME, self.config_entry.title) try: - validate_sql_select(query) + query = validate_sql_select(query) db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( validate_query, db_url_for_validation, query, column @@ -217,9 +229,14 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): except NoSuchColumnError: errors["column"] = "column_invalid" description_placeholders = {"column": column} + except MultipleResultsFound: + errors["query"] = "multiple_queries" except SQLAlchemyError: errors["db_url"] = "db_url_invalid" - except ValueError: + except SQLParseError: + errors["query"] = "query_no_read_only" + except ValueError as err: + _LOGGER.debug("Invalid query: %s", err) errors["query"] = "query_invalid" else: recorder_db = get_instance(self.hass).db_url diff --git a/homeassistant/components/sql/manifest.json b/homeassistant/components/sql/manifest.json index c63ba19e0ad..5ebd79b09a5 100644 --- a/homeassistant/components/sql/manifest.json +++ b/homeassistant/components/sql/manifest.json @@ -5,5 +5,5 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/sql", "iot_class": "local_polling", - "requirements": ["SQLAlchemy==2.0.23"] + "requirements": ["SQLAlchemy==2.0.23", "sqlparse==0.4.4"] } diff --git a/homeassistant/components/sql/strings.json b/homeassistant/components/sql/strings.json index b4bb73d4b99..361585b8876 100644 --- a/homeassistant/components/sql/strings.json +++ b/homeassistant/components/sql/strings.json @@ -6,6 +6,8 @@ "error": { "db_url_invalid": "Database URL invalid", "query_invalid": "SQL Query invalid", + "query_no_read_only": "SQL query must be read-only", + "multiple_queries": "Multiple SQL queries are not supported", "column_invalid": "The column `{column}` is not returned by the query" }, "step": { @@ -61,6 +63,8 @@ "error": { "db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]", "query_invalid": "[%key:component::sql::config::error::query_invalid%]", + "query_no_read_only": "[%key:component::sql::config::error::query_no_read_only%]", + "multiple_queries": "[%key:component::sql::config::error::multiple_queries%]", "column_invalid": "[%key:component::sql::config::error::column_invalid%]" } }, diff --git a/requirements_all.txt b/requirements_all.txt index 1a49b477398..7d97a3f1604 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2535,6 +2535,9 @@ spiderpy==1.6.1 # homeassistant.components.spotify spotipy==2.23.0 +# homeassistant.components.sql +sqlparse==0.4.4 + # homeassistant.components.srp_energy srpenergy==1.3.6 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 0a78ae734ee..2b5cecac4c8 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1909,6 +1909,9 @@ spiderpy==1.6.1 # homeassistant.components.spotify spotipy==2.23.0 +# homeassistant.components.sql +sqlparse==0.4.4 + # homeassistant.components.srp_energy srpenergy==1.3.6 diff --git a/tests/components/sql/__init__.py b/tests/components/sql/__init__.py index 6a629f9603d..9cdd026bd3b 100644 --- a/tests/components/sql/__init__.py +++ b/tests/components/sql/__init__.py @@ -46,17 +46,104 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = { ENTRY_CONFIG_INVALID_QUERY = { CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 FROM as value", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_INVALID_QUERY_2 = { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT5 FROM as value", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_INVALID_QUERY_3 = { + CONF_NAME: "Get Value", + CONF_QUERY: ";;", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_INVALID_QUERY_OPT = { + CONF_QUERY: "SELECT 5 FROM as value", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_INVALID_QUERY_2_OPT = { + CONF_QUERY: "SELECT5 FROM as value", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_INVALID_QUERY_3_OPT = { + CONF_QUERY: ";;", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_QUERY_READ_ONLY_CTE = { + CONF_NAME: "Get Value", + CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;", + CONF_COLUMN_NAME: "state", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + +ENTRY_CONFIG_QUERY_NO_READ_ONLY = { + CONF_NAME: "Get Value", + CONF_QUERY: "UPDATE states SET state = 999999 WHERE state_id = 11125", + CONF_COLUMN_NAME: "state", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + +ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE = { + CONF_NAME: "Get Value", + CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;", + CONF_COLUMN_NAME: "size", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + +ENTRY_CONFIG_QUERY_READ_ONLY_CTE_OPT = { + CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;", + CONF_COLUMN_NAME: "state", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + +ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT = { CONF_QUERY: "UPDATE 5 as value", CONF_COLUMN_NAME: "size", CONF_UNIT_OF_MEASUREMENT: "MiB", } -ENTRY_CONFIG_INVALID_QUERY_OPT = { - CONF_QUERY: "UPDATE 5 as value", +ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT = { + CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;", CONF_COLUMN_NAME: "size", CONF_UNIT_OF_MEASUREMENT: "MiB", } + +ENTRY_CONFIG_MULTIPLE_QUERIES = { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;", + CONF_COLUMN_NAME: "state", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + +ENTRY_CONFIG_MULTIPLE_QUERIES_OPT = { + CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;", + CONF_COLUMN_NAME: "state", + CONF_UNIT_OF_MEASUREMENT: "MiB", +} + + ENTRY_CONFIG_INVALID_COLUMN_NAME = { CONF_NAME: "Get Value", CONF_QUERY: "SELECT 5 as value", diff --git a/tests/components/sql/test_config_flow.py b/tests/components/sql/test_config_flow.py index 6517e319fe4..43608d0d32a 100644 --- a/tests/components/sql/test_config_flow.py +++ b/tests/components/sql/test_config_flow.py @@ -17,8 +17,18 @@ from . import ( ENTRY_CONFIG_INVALID_COLUMN_NAME, ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT, ENTRY_CONFIG_INVALID_QUERY, + ENTRY_CONFIG_INVALID_QUERY_2, + ENTRY_CONFIG_INVALID_QUERY_2_OPT, + ENTRY_CONFIG_INVALID_QUERY_3, + ENTRY_CONFIG_INVALID_QUERY_3_OPT, ENTRY_CONFIG_INVALID_QUERY_OPT, + ENTRY_CONFIG_MULTIPLE_QUERIES, + ENTRY_CONFIG_MULTIPLE_QUERIES_OPT, ENTRY_CONFIG_NO_RESULTS, + ENTRY_CONFIG_QUERY_NO_READ_ONLY, + ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE, + ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT, + ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT, ENTRY_CONFIG_WITH_VALUE_TEMPLATE, ) @@ -132,6 +142,56 @@ async def test_flow_fails_invalid_query( "query": "query_invalid", } + result6 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + user_input=ENTRY_CONFIG_INVALID_QUERY_2, + ) + + assert result6["type"] == FlowResultType.FORM + assert result6["errors"] == { + "query": "query_invalid", + } + + result6 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + user_input=ENTRY_CONFIG_INVALID_QUERY_3, + ) + + assert result6["type"] == FlowResultType.FORM + assert result6["errors"] == { + "query": "query_invalid", + } + + result5 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY, + ) + + assert result5["type"] == FlowResultType.FORM + assert result5["errors"] == { + "query": "query_no_read_only", + } + + result6 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE, + ) + + assert result6["type"] == FlowResultType.FORM + assert result6["errors"] == { + "query": "query_no_read_only", + } + + result6 = await hass.config_entries.flow.async_configure( + result4["flow_id"], + user_input=ENTRY_CONFIG_MULTIPLE_QUERIES, + ) + + assert result6["type"] == FlowResultType.FORM + assert result6["errors"] == { + "query": "multiple_queries", + } + result5 = await hass.config_entries.flow.async_configure( result4["flow_id"], user_input=ENTRY_CONFIG_NO_RESULTS, @@ -380,6 +440,56 @@ async def test_options_flow_fails_invalid_query( "query": "query_invalid", } + result3 = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_INVALID_QUERY_2_OPT, + ) + + assert result3["type"] == FlowResultType.FORM + assert result3["errors"] == { + "query": "query_invalid", + } + + result3 = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_INVALID_QUERY_3_OPT, + ) + + assert result3["type"] == FlowResultType.FORM + assert result3["errors"] == { + "query": "query_invalid", + } + + result2 = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT, + ) + + assert result2["type"] == FlowResultType.FORM + assert result2["errors"] == { + "query": "query_no_read_only", + } + + result3 = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT, + ) + + assert result3["type"] == FlowResultType.FORM + assert result3["errors"] == { + "query": "query_no_read_only", + } + + result3 = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_MULTIPLE_QUERIES_OPT, + ) + + assert result3["type"] == FlowResultType.FORM + assert result3["errors"] == { + "query": "multiple_queries", + } + result4 = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ diff --git a/tests/components/sql/test_init.py b/tests/components/sql/test_init.py index 50de8aba7b3..2ae6010e0c5 100644 --- a/tests/components/sql/test_init.py +++ b/tests/components/sql/test_init.py @@ -58,6 +58,32 @@ async def test_invalid_query(hass: HomeAssistant) -> None: with pytest.raises(vol.Invalid): validate_sql_select("DROP TABLE *") + with pytest.raises(vol.Invalid): + validate_sql_select("SELECT5 as value") + + with pytest.raises(vol.Invalid): + validate_sql_select(";;") + + +async def test_query_no_read_only(hass: HomeAssistant) -> None: + """Test query no read only.""" + with pytest.raises(vol.Invalid): + validate_sql_select("UPDATE states SET state = 999999 WHERE state_id = 11125") + + +async def test_query_no_read_only_cte(hass: HomeAssistant) -> None: + """Test query no read only CTE.""" + with pytest.raises(vol.Invalid): + validate_sql_select( + "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;" + ) + + +async def test_multiple_queries(hass: HomeAssistant) -> None: + """Test multiple queries.""" + with pytest.raises(vol.Invalid): + validate_sql_select("SELECT 5 as value; UPDATE states SET state = 10;") + async def test_remove_configured_db_url_if_not_needed_when_not_needed( recorder_mock: Recorder, diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index cdc9a8e07a6..9ac22f48312 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -57,6 +57,22 @@ async def test_query_basic(recorder_mock: Recorder, hass: HomeAssistant) -> None assert state.attributes["value"] == 5 +async def test_query_cte(recorder_mock: Recorder, hass: HomeAssistant) -> None: + """Test the SQL sensor with CTE.""" + config = { + "db_url": "sqlite://", + "query": "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;", + "column": "state", + "name": "Select value SQL query CTE", + "unique_id": "very_unique_id", + } + await init_integration(hass, config) + + state = hass.states.get("sensor.select_value_sql_query_cte") + assert state.state == "10" + assert state.attributes["state"] == 10 + + async def test_query_value_template( recorder_mock: Recorder, hass: HomeAssistant ) -> None: