mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Add query type validation independent of declaration position for SQL (#105921)
* Add query type validation independent of declaration position * Restore close sess * Separates invalid query and non-read-only query tests * Add more tests * Use the SQLParseError exception for queries that are not read-only * Add handling for multiple SQL queries. * Fix test * Clean ';' at the beginning of the SQL query * Clean ';' at the beginning of the SQL query - init * Query cleaning before storing * Query cleaning before setup sesensor plataform - YAML * Exception when the SQL query type is not detected * Cleaning * Cleaning * Fix typing in tests * Fix typing in tests * Add test for query = ';;' * Update homeassistant/components/sql/__init__.py Co-authored-by: G Johansson <goran.johansson@shiftit.se> * Update homeassistant/components/sql/__init__.py Co-authored-by: G Johansson <goran.johansson@shiftit.se> * Update __init__.py * Update config_flow.py * Clean query before storing --------- Co-authored-by: G Johansson <goran.johansson@shiftit.se>
This commit is contained in:
parent
37707edc47
commit
65c21438a6
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import sqlparse
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.recorder import CONF_DB_URL, get_instance
|
||||
@ -38,9 +39,14 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
def validate_sql_select(value: str) -> str:
|
||||
"""Validate that value is a SQL SELECT query."""
|
||||
if not value.lstrip().lower().startswith("select"):
|
||||
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
|
||||
raise vol.Invalid("Multiple SQL queries are not supported")
|
||||
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
|
||||
raise vol.Invalid("Invalid SQL query")
|
||||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
|
||||
raise vol.Invalid("Only SELECT queries allowed")
|
||||
return value
|
||||
return str(query[0])
|
||||
|
||||
|
||||
QUERY_SCHEMA = vol.Schema(
|
||||
|
@ -6,8 +6,10 @@ from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine import Result
|
||||
from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError
|
||||
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
import sqlparse
|
||||
from sqlparse.exceptions import SQLParseError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
@ -80,11 +82,16 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema(
|
||||
).extend(OPTIONS_SCHEMA.schema)
|
||||
|
||||
|
||||
def validate_sql_select(value: str) -> str | None:
|
||||
def validate_sql_select(value: str) -> str:
|
||||
"""Validate that value is a SQL SELECT query."""
|
||||
if not value.lstrip().lower().startswith("select"):
|
||||
raise ValueError("Incorrect Query")
|
||||
return value
|
||||
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
|
||||
raise MultipleResultsFound
|
||||
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
|
||||
raise ValueError
|
||||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
|
||||
raise SQLParseError
|
||||
return str(query[0])
|
||||
|
||||
|
||||
def validate_query(db_url: str, query: str, column: str) -> bool:
|
||||
@ -148,7 +155,7 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
db_url_for_validation = None
|
||||
|
||||
try:
|
||||
validate_sql_select(query)
|
||||
query = validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url_for_validation, query, column
|
||||
@ -156,9 +163,14 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
except NoSuchColumnError:
|
||||
errors["column"] = "column_invalid"
|
||||
description_placeholders = {"column": column}
|
||||
except MultipleResultsFound:
|
||||
errors["query"] = "multiple_queries"
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
except ValueError:
|
||||
except SQLParseError:
|
||||
errors["query"] = "query_no_read_only"
|
||||
except ValueError as err:
|
||||
_LOGGER.debug("Invalid query: %s", err)
|
||||
errors["query"] = "query_invalid"
|
||||
|
||||
options = {
|
||||
@ -209,7 +221,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
|
||||
name = self.options.get(CONF_NAME, self.config_entry.title)
|
||||
|
||||
try:
|
||||
validate_sql_select(query)
|
||||
query = validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url_for_validation, query, column
|
||||
@ -217,9 +229,14 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
|
||||
except NoSuchColumnError:
|
||||
errors["column"] = "column_invalid"
|
||||
description_placeholders = {"column": column}
|
||||
except MultipleResultsFound:
|
||||
errors["query"] = "multiple_queries"
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
except ValueError:
|
||||
except SQLParseError:
|
||||
errors["query"] = "query_no_read_only"
|
||||
except ValueError as err:
|
||||
_LOGGER.debug("Invalid query: %s", err)
|
||||
errors["query"] = "query_invalid"
|
||||
else:
|
||||
recorder_db = get_instance(self.hass).db_url
|
||||
|
@ -5,5 +5,5 @@
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/sql",
|
||||
"iot_class": "local_polling",
|
||||
"requirements": ["SQLAlchemy==2.0.23"]
|
||||
"requirements": ["SQLAlchemy==2.0.23", "sqlparse==0.4.4"]
|
||||
}
|
||||
|
@ -6,6 +6,8 @@
|
||||
"error": {
|
||||
"db_url_invalid": "Database URL invalid",
|
||||
"query_invalid": "SQL Query invalid",
|
||||
"query_no_read_only": "SQL query must be read-only",
|
||||
"multiple_queries": "Multiple SQL queries are not supported",
|
||||
"column_invalid": "The column `{column}` is not returned by the query"
|
||||
},
|
||||
"step": {
|
||||
@ -61,6 +63,8 @@
|
||||
"error": {
|
||||
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
|
||||
"query_invalid": "[%key:component::sql::config::error::query_invalid%]",
|
||||
"query_no_read_only": "[%key:component::sql::config::error::query_no_read_only%]",
|
||||
"multiple_queries": "[%key:component::sql::config::error::multiple_queries%]",
|
||||
"column_invalid": "[%key:component::sql::config::error::column_invalid%]"
|
||||
}
|
||||
},
|
||||
|
@ -2535,6 +2535,9 @@ spiderpy==1.6.1
|
||||
# homeassistant.components.spotify
|
||||
spotipy==2.23.0
|
||||
|
||||
# homeassistant.components.sql
|
||||
sqlparse==0.4.4
|
||||
|
||||
# homeassistant.components.srp_energy
|
||||
srpenergy==1.3.6
|
||||
|
||||
|
@ -1909,6 +1909,9 @@ spiderpy==1.6.1
|
||||
# homeassistant.components.spotify
|
||||
spotipy==2.23.0
|
||||
|
||||
# homeassistant.components.sql
|
||||
sqlparse==0.4.4
|
||||
|
||||
# homeassistant.components.srp_energy
|
||||
srpenergy==1.3.6
|
||||
|
||||
|
@ -46,17 +46,104 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_2 = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_3 = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: ";;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_OPT = {
|
||||
CONF_QUERY: "SELECT 5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_2_OPT = {
|
||||
CONF_QUERY: "SELECT5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_3_OPT = {
|
||||
CONF_QUERY: ";;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_QUERY_READ_ONLY_CTE = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "UPDATE states SET state = 999999 WHERE state_id = 11125",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_READ_ONLY_CTE_OPT = {
|
||||
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT = {
|
||||
CONF_QUERY: "UPDATE 5 as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_OPT = {
|
||||
CONF_QUERY: "UPDATE 5 as value",
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT = {
|
||||
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_MULTIPLE_QUERIES = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT = {
|
||||
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_COLUMN_NAME = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
|
@ -17,8 +17,18 @@ from . import (
|
||||
ENTRY_CONFIG_INVALID_COLUMN_NAME,
|
||||
ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT,
|
||||
ENTRY_CONFIG_INVALID_QUERY,
|
||||
ENTRY_CONFIG_INVALID_QUERY_2,
|
||||
ENTRY_CONFIG_INVALID_QUERY_2_OPT,
|
||||
ENTRY_CONFIG_INVALID_QUERY_3,
|
||||
ENTRY_CONFIG_INVALID_QUERY_3_OPT,
|
||||
ENTRY_CONFIG_INVALID_QUERY_OPT,
|
||||
ENTRY_CONFIG_MULTIPLE_QUERIES,
|
||||
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT,
|
||||
ENTRY_CONFIG_NO_RESULTS,
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY,
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
|
||||
ENTRY_CONFIG_WITH_VALUE_TEMPLATE,
|
||||
)
|
||||
|
||||
@ -132,6 +142,56 @@ async def test_flow_fails_invalid_query(
|
||||
"query": "query_invalid",
|
||||
}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
user_input=ENTRY_CONFIG_INVALID_QUERY_2,
|
||||
)
|
||||
|
||||
assert result6["type"] == FlowResultType.FORM
|
||||
assert result6["errors"] == {
|
||||
"query": "query_invalid",
|
||||
}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
user_input=ENTRY_CONFIG_INVALID_QUERY_3,
|
||||
)
|
||||
|
||||
assert result6["type"] == FlowResultType.FORM
|
||||
assert result6["errors"] == {
|
||||
"query": "query_invalid",
|
||||
}
|
||||
|
||||
result5 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY,
|
||||
)
|
||||
|
||||
assert result5["type"] == FlowResultType.FORM
|
||||
assert result5["errors"] == {
|
||||
"query": "query_no_read_only",
|
||||
}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
|
||||
)
|
||||
|
||||
assert result6["type"] == FlowResultType.FORM
|
||||
assert result6["errors"] == {
|
||||
"query": "query_no_read_only",
|
||||
}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
user_input=ENTRY_CONFIG_MULTIPLE_QUERIES,
|
||||
)
|
||||
|
||||
assert result6["type"] == FlowResultType.FORM
|
||||
assert result6["errors"] == {
|
||||
"query": "multiple_queries",
|
||||
}
|
||||
|
||||
result5 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
user_input=ENTRY_CONFIG_NO_RESULTS,
|
||||
@ -380,6 +440,56 @@ async def test_options_flow_fails_invalid_query(
|
||||
"query": "query_invalid",
|
||||
}
|
||||
|
||||
result3 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_INVALID_QUERY_2_OPT,
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {
|
||||
"query": "query_invalid",
|
||||
}
|
||||
|
||||
result3 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_INVALID_QUERY_3_OPT,
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {
|
||||
"query": "query_invalid",
|
||||
}
|
||||
|
||||
result2 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
|
||||
)
|
||||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {
|
||||
"query": "query_no_read_only",
|
||||
}
|
||||
|
||||
result3 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {
|
||||
"query": "query_no_read_only",
|
||||
}
|
||||
|
||||
result3 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_MULTIPLE_QUERIES_OPT,
|
||||
)
|
||||
|
||||
assert result3["type"] == FlowResultType.FORM
|
||||
assert result3["errors"] == {
|
||||
"query": "multiple_queries",
|
||||
}
|
||||
|
||||
result4 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
|
@ -58,6 +58,32 @@ async def test_invalid_query(hass: HomeAssistant) -> None:
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("DROP TABLE *")
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("SELECT5 as value")
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select(";;")
|
||||
|
||||
|
||||
async def test_query_no_read_only(hass: HomeAssistant) -> None:
|
||||
"""Test query no read only."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("UPDATE states SET state = 999999 WHERE state_id = 11125")
|
||||
|
||||
|
||||
async def test_query_no_read_only_cte(hass: HomeAssistant) -> None:
|
||||
"""Test query no read only CTE."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select(
|
||||
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;"
|
||||
)
|
||||
|
||||
|
||||
async def test_multiple_queries(hass: HomeAssistant) -> None:
|
||||
"""Test multiple queries."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("SELECT 5 as value; UPDATE states SET state = 10;")
|
||||
|
||||
|
||||
async def test_remove_configured_db_url_if_not_needed_when_not_needed(
|
||||
recorder_mock: Recorder,
|
||||
|
@ -57,6 +57,22 @@ async def test_query_basic(recorder_mock: Recorder, hass: HomeAssistant) -> None
|
||||
assert state.attributes["value"] == 5
|
||||
|
||||
|
||||
async def test_query_cte(recorder_mock: Recorder, hass: HomeAssistant) -> None:
|
||||
"""Test the SQL sensor with CTE."""
|
||||
config = {
|
||||
"db_url": "sqlite://",
|
||||
"query": "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
|
||||
"column": "state",
|
||||
"name": "Select value SQL query CTE",
|
||||
"unique_id": "very_unique_id",
|
||||
}
|
||||
await init_integration(hass, config)
|
||||
|
||||
state = hass.states.get("sensor.select_value_sql_query_cte")
|
||||
assert state.state == "10"
|
||||
assert state.attributes["state"] == 10
|
||||
|
||||
|
||||
async def test_query_value_template(
|
||||
recorder_mock: Recorder, hass: HomeAssistant
|
||||
) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user