Add query type validation independent of declaration position for SQL (#105921)

* Add query type validation independent of declaration position

* Restore close sess

* Separate invalid query and non-read-only query tests

* Add more tests

* Use the SQLParseError exception for queries that are not read-only

* Add handling for multiple SQL queries.

* Fix test

* Clean ';' at the beginning of the SQL query

* Clean ';' at the beginning of the SQL query - init

* Query cleaning before storing

* Query cleaning before setting up the sensor platform - YAML

* Exception when the SQL query type is not detected

* Cleaning

* Cleaning

* Fix typing in tests

* Fix typing in tests

* Add test for query = ';;'

* Update homeassistant/components/sql/__init__.py

Co-authored-by: G Johansson <goran.johansson@shiftit.se>

* Update homeassistant/components/sql/__init__.py

Co-authored-by: G Johansson <goran.johansson@shiftit.se>

* Update __init__.py

* Update config_flow.py

* Clean query before storing

---------

Co-authored-by: G Johansson <goran.johansson@shiftit.se>
This commit is contained in:
dougiteixeira 2023-12-27 13:58:35 -03:00 committed by GitHub
parent 37707edc47
commit 65c21438a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 286 additions and 14 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
import sqlparse
import voluptuous as vol import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL, get_instance from homeassistant.components.recorder import CONF_DB_URL, get_instance
@ -38,9 +39,14 @@ _LOGGER = logging.getLogger(__name__)
def validate_sql_select(value: str) -> str: def validate_sql_select(value: str) -> str:
"""Validate that value is a SQL SELECT query.""" """Validate that value is a SQL SELECT query."""
if not value.lstrip().lower().startswith("select"): if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
raise vol.Invalid("Multiple SQL queries are not supported")
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
raise vol.Invalid("Invalid SQL query")
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
raise vol.Invalid("Only SELECT queries allowed") raise vol.Invalid("Only SELECT queries allowed")
return value return str(query[0])
QUERY_SCHEMA = vol.Schema( QUERY_SCHEMA = vol.Schema(

View File

@ -6,8 +6,10 @@ from typing import Any
import sqlalchemy import sqlalchemy
from sqlalchemy.engine import Result from sqlalchemy.engine import Result
from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker from sqlalchemy.orm import Session, scoped_session, sessionmaker
import sqlparse
from sqlparse.exceptions import SQLParseError
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
@ -80,11 +82,16 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema(
).extend(OPTIONS_SCHEMA.schema) ).extend(OPTIONS_SCHEMA.schema)
def validate_sql_select(value: str) -> str | None: def validate_sql_select(value: str) -> str:
"""Validate that value is a SQL SELECT query.""" """Validate that value is a SQL SELECT query."""
if not value.lstrip().lower().startswith("select"): if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
raise ValueError("Incorrect Query") raise MultipleResultsFound
return value if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
raise ValueError
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
raise SQLParseError
return str(query[0])
def validate_query(db_url: str, query: str, column: str) -> bool: def validate_query(db_url: str, query: str, column: str) -> bool:
@ -148,7 +155,7 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
db_url_for_validation = None db_url_for_validation = None
try: try:
validate_sql_select(query) query = validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url) db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job( await self.hass.async_add_executor_job(
validate_query, db_url_for_validation, query, column validate_query, db_url_for_validation, query, column
@ -156,9 +163,14 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except NoSuchColumnError: except NoSuchColumnError:
errors["column"] = "column_invalid" errors["column"] = "column_invalid"
description_placeholders = {"column": column} description_placeholders = {"column": column}
except MultipleResultsFound:
errors["query"] = "multiple_queries"
except SQLAlchemyError: except SQLAlchemyError:
errors["db_url"] = "db_url_invalid" errors["db_url"] = "db_url_invalid"
except ValueError: except SQLParseError:
errors["query"] = "query_no_read_only"
except ValueError as err:
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid" errors["query"] = "query_invalid"
options = { options = {
@ -209,7 +221,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
name = self.options.get(CONF_NAME, self.config_entry.title) name = self.options.get(CONF_NAME, self.config_entry.title)
try: try:
validate_sql_select(query) query = validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url) db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job( await self.hass.async_add_executor_job(
validate_query, db_url_for_validation, query, column validate_query, db_url_for_validation, query, column
@ -217,9 +229,14 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
except NoSuchColumnError: except NoSuchColumnError:
errors["column"] = "column_invalid" errors["column"] = "column_invalid"
description_placeholders = {"column": column} description_placeholders = {"column": column}
except MultipleResultsFound:
errors["query"] = "multiple_queries"
except SQLAlchemyError: except SQLAlchemyError:
errors["db_url"] = "db_url_invalid" errors["db_url"] = "db_url_invalid"
except ValueError: except SQLParseError:
errors["query"] = "query_no_read_only"
except ValueError as err:
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid" errors["query"] = "query_invalid"
else: else:
recorder_db = get_instance(self.hass).db_url recorder_db = get_instance(self.hass).db_url

View File

@ -5,5 +5,5 @@
"config_flow": true, "config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/sql", "documentation": "https://www.home-assistant.io/integrations/sql",
"iot_class": "local_polling", "iot_class": "local_polling",
"requirements": ["SQLAlchemy==2.0.23"] "requirements": ["SQLAlchemy==2.0.23", "sqlparse==0.4.4"]
} }

View File

@ -6,6 +6,8 @@
"error": { "error": {
"db_url_invalid": "Database URL invalid", "db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid", "query_invalid": "SQL Query invalid",
"query_no_read_only": "SQL query must be read-only",
"multiple_queries": "Multiple SQL queries are not supported",
"column_invalid": "The column `{column}` is not returned by the query" "column_invalid": "The column `{column}` is not returned by the query"
}, },
"step": { "step": {
@ -61,6 +63,8 @@
"error": { "error": {
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]", "db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
"query_invalid": "[%key:component::sql::config::error::query_invalid%]", "query_invalid": "[%key:component::sql::config::error::query_invalid%]",
"query_no_read_only": "[%key:component::sql::config::error::query_no_read_only%]",
"multiple_queries": "[%key:component::sql::config::error::multiple_queries%]",
"column_invalid": "[%key:component::sql::config::error::column_invalid%]" "column_invalid": "[%key:component::sql::config::error::column_invalid%]"
} }
}, },

View File

@ -2535,6 +2535,9 @@ spiderpy==1.6.1
# homeassistant.components.spotify # homeassistant.components.spotify
spotipy==2.23.0 spotipy==2.23.0
# homeassistant.components.sql
sqlparse==0.4.4
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.6 srpenergy==1.3.6

View File

@ -1909,6 +1909,9 @@ spiderpy==1.6.1
# homeassistant.components.spotify # homeassistant.components.spotify
spotipy==2.23.0 spotipy==2.23.0
# homeassistant.components.sql
sqlparse==0.4.4
# homeassistant.components.srp_energy # homeassistant.components.srp_energy
srpenergy==1.3.6 srpenergy==1.3.6

View File

@ -46,17 +46,104 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
ENTRY_CONFIG_INVALID_QUERY = { ENTRY_CONFIG_INVALID_QUERY = {
CONF_NAME: "Get Value", CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_2 = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_3 = {
CONF_NAME: "Get Value",
CONF_QUERY: ";;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_OPT = {
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_2_OPT = {
CONF_QUERY: "SELECT5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_3_OPT = {
CONF_QUERY: ";;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_READ_ONLY_CTE = {
CONF_NAME: "Get Value",
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY = {
CONF_NAME: "Get Value",
CONF_QUERY: "UPDATE states SET state = 999999 WHERE state_id = 11125",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE = {
CONF_NAME: "Get Value",
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_READ_ONLY_CTE_OPT = {
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT = {
CONF_QUERY: "UPDATE 5 as value", CONF_QUERY: "UPDATE 5 as value",
CONF_COLUMN_NAME: "size", CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB", CONF_UNIT_OF_MEASUREMENT: "MiB",
} }
ENTRY_CONFIG_INVALID_QUERY_OPT = { ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT = {
CONF_QUERY: "UPDATE 5 as value", CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
CONF_COLUMN_NAME: "size", CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB", CONF_UNIT_OF_MEASUREMENT: "MiB",
} }
ENTRY_CONFIG_MULTIPLE_QUERIES = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT = {
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_COLUMN_NAME = { ENTRY_CONFIG_INVALID_COLUMN_NAME = {
CONF_NAME: "Get Value", CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value", CONF_QUERY: "SELECT 5 as value",

View File

@ -17,8 +17,18 @@ from . import (
ENTRY_CONFIG_INVALID_COLUMN_NAME, ENTRY_CONFIG_INVALID_COLUMN_NAME,
ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT, ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT,
ENTRY_CONFIG_INVALID_QUERY, ENTRY_CONFIG_INVALID_QUERY,
ENTRY_CONFIG_INVALID_QUERY_2,
ENTRY_CONFIG_INVALID_QUERY_2_OPT,
ENTRY_CONFIG_INVALID_QUERY_3,
ENTRY_CONFIG_INVALID_QUERY_3_OPT,
ENTRY_CONFIG_INVALID_QUERY_OPT, ENTRY_CONFIG_INVALID_QUERY_OPT,
ENTRY_CONFIG_MULTIPLE_QUERIES,
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT,
ENTRY_CONFIG_NO_RESULTS, ENTRY_CONFIG_NO_RESULTS,
ENTRY_CONFIG_QUERY_NO_READ_ONLY,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
ENTRY_CONFIG_WITH_VALUE_TEMPLATE, ENTRY_CONFIG_WITH_VALUE_TEMPLATE,
) )
@ -132,6 +142,56 @@ async def test_flow_fails_invalid_query(
"query": "query_invalid", "query": "query_invalid",
} }
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_2,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "query_invalid",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_3,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "query_invalid",
}
result5 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY,
)
assert result5["type"] == FlowResultType.FORM
assert result5["errors"] == {
"query": "query_no_read_only",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "query_no_read_only",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_MULTIPLE_QUERIES,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "multiple_queries",
}
result5 = await hass.config_entries.flow.async_configure( result5 = await hass.config_entries.flow.async_configure(
result4["flow_id"], result4["flow_id"],
user_input=ENTRY_CONFIG_NO_RESULTS, user_input=ENTRY_CONFIG_NO_RESULTS,
@ -380,6 +440,56 @@ async def test_options_flow_fails_invalid_query(
"query": "query_invalid", "query": "query_invalid",
} }
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_2_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "query_invalid",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_3_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "query_invalid",
}
result2 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {
"query": "query_no_read_only",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "query_no_read_only",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_MULTIPLE_QUERIES_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "multiple_queries",
}
result4 = await hass.config_entries.options.async_configure( result4 = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={ user_input={

View File

@ -58,6 +58,32 @@ async def test_invalid_query(hass: HomeAssistant) -> None:
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
validate_sql_select("DROP TABLE *") validate_sql_select("DROP TABLE *")
with pytest.raises(vol.Invalid):
validate_sql_select("SELECT5 as value")
with pytest.raises(vol.Invalid):
validate_sql_select(";;")
async def test_query_no_read_only(hass: HomeAssistant) -> None:
"""Test query no read only."""
with pytest.raises(vol.Invalid):
validate_sql_select("UPDATE states SET state = 999999 WHERE state_id = 11125")
async def test_query_no_read_only_cte(hass: HomeAssistant) -> None:
"""Test query no read only CTE."""
with pytest.raises(vol.Invalid):
validate_sql_select(
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;"
)
async def test_multiple_queries(hass: HomeAssistant) -> None:
"""Test multiple queries."""
with pytest.raises(vol.Invalid):
validate_sql_select("SELECT 5 as value; UPDATE states SET state = 10;")
async def test_remove_configured_db_url_if_not_needed_when_not_needed( async def test_remove_configured_db_url_if_not_needed_when_not_needed(
recorder_mock: Recorder, recorder_mock: Recorder,

View File

@ -57,6 +57,22 @@ async def test_query_basic(recorder_mock: Recorder, hass: HomeAssistant) -> None
assert state.attributes["value"] == 5 assert state.attributes["value"] == 5
async def test_query_cte(recorder_mock: Recorder, hass: HomeAssistant) -> None:
"""Test the SQL sensor with CTE."""
config = {
"db_url": "sqlite://",
"query": "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
"column": "state",
"name": "Select value SQL query CTE",
"unique_id": "very_unique_id",
}
await init_integration(hass, config)
state = hass.states.get("sensor.select_value_sql_query_cte")
assert state.state == "10"
assert state.attributes["state"] == 10
async def test_query_value_template( async def test_query_value_template(
recorder_mock: Recorder, hass: HomeAssistant recorder_mock: Recorder, hass: HomeAssistant
) -> None: ) -> None: