Add query type validation independent of declaration position for SQL (#105921)

* Add query type validation independent of declaration position

* Restore close sess

* Separates invalid query and non-read-only query tests

* Add more tests

* Use the SQLParseError exception for queries that are not read-only

* Add handling for multiple SQL queries.

* Fix test

* Clean ';' at the beginning of the SQL query

* Clean ';' at the beginning of the SQL query - init

* Query cleaning before storing

* Query cleaning before setting up the sensor platform - YAML

* Exception when the SQL query type is not detected

* Cleaning

* Cleaning

* Fix typing in tests

* Fix typing in tests

* Add test for query = ';;'

* Update homeassistant/components/sql/__init__.py

Co-authored-by: G Johansson <goran.johansson@shiftit.se>

* Update homeassistant/components/sql/__init__.py

Co-authored-by: G Johansson <goran.johansson@shiftit.se>

* Update __init__.py

* Update config_flow.py

* Clean query before storing

---------

Co-authored-by: G Johansson <goran.johansson@shiftit.se>
This commit is contained in:
dougiteixeira 2023-12-27 13:58:35 -03:00 committed by GitHub
parent 37707edc47
commit 65c21438a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 286 additions and 14 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import sqlparse
import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL, get_instance
@ -38,9 +39,14 @@ _LOGGER = logging.getLogger(__name__)
def validate_sql_select(value: str) -> str:
"""Validate that value is a SQL SELECT query."""
if not value.lstrip().lower().startswith("select"):
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
raise vol.Invalid("Multiple SQL queries are not supported")
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
raise vol.Invalid("Invalid SQL query")
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
raise vol.Invalid("Only SELECT queries allowed")
return value
return str(query[0])
QUERY_SCHEMA = vol.Schema(

View File

@ -6,8 +6,10 @@ from typing import Any
import sqlalchemy
from sqlalchemy.engine import Result
from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker
import sqlparse
from sqlparse.exceptions import SQLParseError
import voluptuous as vol
from homeassistant import config_entries
@ -80,11 +82,16 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema(
).extend(OPTIONS_SCHEMA.schema)
def validate_sql_select(value: str) -> str | None:
def validate_sql_select(value: str) -> str:
"""Validate that value is a SQL SELECT query."""
if not value.lstrip().lower().startswith("select"):
raise ValueError("Incorrect Query")
return value
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
raise MultipleResultsFound
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
raise ValueError
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
raise SQLParseError
return str(query[0])
def validate_query(db_url: str, query: str, column: str) -> bool:
@ -148,7 +155,7 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
db_url_for_validation = None
try:
validate_sql_select(query)
query = validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job(
validate_query, db_url_for_validation, query, column
@ -156,9 +163,14 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except NoSuchColumnError:
errors["column"] = "column_invalid"
description_placeholders = {"column": column}
except MultipleResultsFound:
errors["query"] = "multiple_queries"
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
except ValueError:
except SQLParseError:
errors["query"] = "query_no_read_only"
except ValueError as err:
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid"
options = {
@ -209,7 +221,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
name = self.options.get(CONF_NAME, self.config_entry.title)
try:
validate_sql_select(query)
query = validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job(
validate_query, db_url_for_validation, query, column
@ -217,9 +229,14 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
except NoSuchColumnError:
errors["column"] = "column_invalid"
description_placeholders = {"column": column}
except MultipleResultsFound:
errors["query"] = "multiple_queries"
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
except ValueError:
except SQLParseError:
errors["query"] = "query_no_read_only"
except ValueError as err:
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid"
else:
recorder_db = get_instance(self.hass).db_url

View File

@ -5,5 +5,5 @@
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/sql",
"iot_class": "local_polling",
"requirements": ["SQLAlchemy==2.0.23"]
"requirements": ["SQLAlchemy==2.0.23", "sqlparse==0.4.4"]
}

View File

@ -6,6 +6,8 @@
"error": {
"db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid",
"query_no_read_only": "SQL query must be read-only",
"multiple_queries": "Multiple SQL queries are not supported",
"column_invalid": "The column `{column}` is not returned by the query"
},
"step": {
@ -61,6 +63,8 @@
"error": {
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
"query_invalid": "[%key:component::sql::config::error::query_invalid%]",
"query_no_read_only": "[%key:component::sql::config::error::query_no_read_only%]",
"multiple_queries": "[%key:component::sql::config::error::multiple_queries%]",
"column_invalid": "[%key:component::sql::config::error::column_invalid%]"
}
},

View File

@ -2535,6 +2535,9 @@ spiderpy==1.6.1
# homeassistant.components.spotify
spotipy==2.23.0
# homeassistant.components.sql
sqlparse==0.4.4
# homeassistant.components.srp_energy
srpenergy==1.3.6

View File

@ -1909,6 +1909,9 @@ spiderpy==1.6.1
# homeassistant.components.spotify
spotipy==2.23.0
# homeassistant.components.sql
sqlparse==0.4.4
# homeassistant.components.srp_energy
srpenergy==1.3.6

View File

@ -46,17 +46,104 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
ENTRY_CONFIG_INVALID_QUERY = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_2 = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_3 = {
CONF_NAME: "Get Value",
CONF_QUERY: ";;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_OPT = {
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_2_OPT = {
CONF_QUERY: "SELECT5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_3_OPT = {
CONF_QUERY: ";;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_READ_ONLY_CTE = {
CONF_NAME: "Get Value",
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY = {
CONF_NAME: "Get Value",
CONF_QUERY: "UPDATE states SET state = 999999 WHERE state_id = 11125",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE = {
CONF_NAME: "Get Value",
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_READ_ONLY_CTE_OPT = {
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT = {
CONF_QUERY: "UPDATE 5 as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_QUERY_OPT = {
CONF_QUERY: "UPDATE 5 as value",
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT = {
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_MULTIPLE_QUERIES = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT = {
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_INVALID_COLUMN_NAME = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",

View File

@ -17,8 +17,18 @@ from . import (
ENTRY_CONFIG_INVALID_COLUMN_NAME,
ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT,
ENTRY_CONFIG_INVALID_QUERY,
ENTRY_CONFIG_INVALID_QUERY_2,
ENTRY_CONFIG_INVALID_QUERY_2_OPT,
ENTRY_CONFIG_INVALID_QUERY_3,
ENTRY_CONFIG_INVALID_QUERY_3_OPT,
ENTRY_CONFIG_INVALID_QUERY_OPT,
ENTRY_CONFIG_MULTIPLE_QUERIES,
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT,
ENTRY_CONFIG_NO_RESULTS,
ENTRY_CONFIG_QUERY_NO_READ_ONLY,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
ENTRY_CONFIG_WITH_VALUE_TEMPLATE,
)
@ -132,6 +142,56 @@ async def test_flow_fails_invalid_query(
"query": "query_invalid",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_2,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "query_invalid",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_3,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "query_invalid",
}
result5 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY,
)
assert result5["type"] == FlowResultType.FORM
assert result5["errors"] == {
"query": "query_no_read_only",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "query_no_read_only",
}
result6 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_MULTIPLE_QUERIES,
)
assert result6["type"] == FlowResultType.FORM
assert result6["errors"] == {
"query": "multiple_queries",
}
result5 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
user_input=ENTRY_CONFIG_NO_RESULTS,
@ -380,6 +440,56 @@ async def test_options_flow_fails_invalid_query(
"query": "query_invalid",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_2_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "query_invalid",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_INVALID_QUERY_3_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "query_invalid",
}
result2 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
)
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {
"query": "query_no_read_only",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "query_no_read_only",
}
result3 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_MULTIPLE_QUERIES_OPT,
)
assert result3["type"] == FlowResultType.FORM
assert result3["errors"] == {
"query": "multiple_queries",
}
result4 = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={

View File

@ -58,6 +58,32 @@ async def test_invalid_query(hass: HomeAssistant) -> None:
with pytest.raises(vol.Invalid):
validate_sql_select("DROP TABLE *")
with pytest.raises(vol.Invalid):
validate_sql_select("SELECT5 as value")
with pytest.raises(vol.Invalid):
validate_sql_select(";;")
async def test_query_no_read_only(hass: HomeAssistant) -> None:
    """Check that a plain UPDATE statement is rejected with vol.Invalid."""
    non_select_query = "UPDATE states SET state = 999999 WHERE state_id = 11125"
    with pytest.raises(vol.Invalid):
        validate_sql_select(non_select_query)
async def test_query_no_read_only_cte(hass: HomeAssistant) -> None:
    """Check that a CTE ending in an UPDATE is rejected with vol.Invalid."""
    cte_update_query = "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;"
    with pytest.raises(vol.Invalid):
        validate_sql_select(cte_update_query)
async def test_multiple_queries(hass: HomeAssistant) -> None:
    """Check that two statements in a single query string raise vol.Invalid."""
    two_statements = "SELECT 5 as value; UPDATE states SET state = 10;"
    with pytest.raises(vol.Invalid):
        validate_sql_select(two_statements)
async def test_remove_configured_db_url_if_not_needed_when_not_needed(
recorder_mock: Recorder,

View File

@ -57,6 +57,22 @@ async def test_query_basic(recorder_mock: Recorder, hass: HomeAssistant) -> None
assert state.attributes["value"] == 5
async def test_query_cte(recorder_mock: Recorder, hass: HomeAssistant) -> None:
    """Set up a sensor from a read-only CTE query and check its state."""
    cte_query = "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;"
    sensor_config = {
        "db_url": "sqlite://",
        "query": cte_query,
        "column": "state",
        "name": "Select value SQL query CTE",
        "unique_id": "very_unique_id",
    }
    await init_integration(hass, sensor_config)

    # The entity id is looked up under the slug of the configured name.
    sensor_state = hass.states.get("sensor.select_value_sql_query_cte")
    assert sensor_state.state == "10"
    assert sensor_state.attributes["state"] == 10
async def test_query_value_template(
recorder_mock: Recorder, hass: HomeAssistant
) -> None: