mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 15:17:35 +00:00
Improve code quality in sql integration (#71705)
This commit is contained in:
parent
ae89a1243a
commit
a746d7c1d7
@ -13,7 +13,7 @@ import voluptuous as vol
|
|||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
||||||
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
|
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers import selector
|
from homeassistant.helpers import selector
|
||||||
|
|
||||||
@ -44,24 +44,23 @@ def validate_sql_select(value: str) -> str | None:
|
|||||||
|
|
||||||
def validate_query(db_url: str, query: str, column: str) -> bool:
|
def validate_query(db_url: str, query: str, column: str) -> bool:
|
||||||
"""Validate SQL query."""
|
"""Validate SQL query."""
|
||||||
try:
|
|
||||||
engine = sqlalchemy.create_engine(db_url, future=True)
|
|
||||||
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
|
||||||
except SQLAlchemyError as error:
|
|
||||||
raise error
|
|
||||||
|
|
||||||
|
engine = sqlalchemy.create_engine(db_url, future=True)
|
||||||
|
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
||||||
sess: scoped_session = sessmaker()
|
sess: scoped_session = sessmaker()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result: Result = sess.execute(sqlalchemy.text(query))
|
result: Result = sess.execute(sqlalchemy.text(query))
|
||||||
for res in result.mappings():
|
|
||||||
data = res[column]
|
|
||||||
_LOGGER.debug("Return value from query: %s", data)
|
|
||||||
except SQLAlchemyError as error:
|
except SQLAlchemyError as error:
|
||||||
|
_LOGGER.debug("Execution error %s", error)
|
||||||
if sess:
|
if sess:
|
||||||
sess.close()
|
sess.close()
|
||||||
raise ValueError(error) from error
|
raise ValueError(error) from error
|
||||||
|
|
||||||
|
for res in result.mappings():
|
||||||
|
data = res[column]
|
||||||
|
_LOGGER.debug("Return value from query: %s", data)
|
||||||
|
|
||||||
if sess:
|
if sess:
|
||||||
sess.close()
|
sess.close()
|
||||||
|
|
||||||
@ -73,9 +72,6 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
|
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
|
|
||||||
entry: config_entries.ConfigEntry
|
|
||||||
hass: HomeAssistant
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@callback
|
@callback
|
||||||
def async_get_options_flow(
|
def async_get_options_flow(
|
||||||
|
@ -6,6 +6,7 @@ import decimal
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
from sqlalchemy.engine import Result
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -73,11 +74,11 @@ async def async_setup_platform(
|
|||||||
for query in config[CONF_QUERIES]:
|
for query in config[CONF_QUERIES]:
|
||||||
new_config = {
|
new_config = {
|
||||||
CONF_DB_URL: config.get(CONF_DB_URL, default_db_url),
|
CONF_DB_URL: config.get(CONF_DB_URL, default_db_url),
|
||||||
CONF_NAME: query.get(CONF_NAME),
|
CONF_NAME: query[CONF_NAME],
|
||||||
CONF_QUERY: query.get(CONF_QUERY),
|
CONF_QUERY: query[CONF_QUERY],
|
||||||
CONF_UNIT_OF_MEASUREMENT: query.get(CONF_UNIT_OF_MEASUREMENT),
|
CONF_UNIT_OF_MEASUREMENT: query.get(CONF_UNIT_OF_MEASUREMENT),
|
||||||
CONF_VALUE_TEMPLATE: query.get(CONF_VALUE_TEMPLATE),
|
CONF_VALUE_TEMPLATE: query.get(CONF_VALUE_TEMPLATE),
|
||||||
CONF_COLUMN_NAME: query.get(CONF_COLUMN_NAME),
|
CONF_COLUMN_NAME: query[CONF_COLUMN_NAME],
|
||||||
}
|
}
|
||||||
hass.async_create_task(
|
hass.async_create_task(
|
||||||
hass.config_entries.flow.async_init(
|
hass.config_entries.flow.async_init(
|
||||||
@ -119,11 +120,10 @@ async def async_setup_entry(
|
|||||||
|
|
||||||
# MSSQL uses TOP and not LIMIT
|
# MSSQL uses TOP and not LIMIT
|
||||||
if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()):
|
if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()):
|
||||||
query_str = (
|
if "mssql" in db_url:
|
||||||
query_str.replace("SELECT", "SELECT TOP 1")
|
query_str = query_str.upper().replace("SELECT", "SELECT TOP 1")
|
||||||
if "mssql" in db_url
|
else:
|
||||||
else query_str.replace(";", " LIMIT 1;")
|
query_str = query_str.replace(";", "") + " LIMIT 1;"
|
||||||
)
|
|
||||||
|
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[
|
[
|
||||||
@ -179,7 +179,7 @@ class SQLSensor(SensorEntity):
|
|||||||
self._attr_extra_state_attributes = {}
|
self._attr_extra_state_attributes = {}
|
||||||
sess: scoped_session = self.sessionmaker()
|
sess: scoped_session = self.sessionmaker()
|
||||||
try:
|
try:
|
||||||
result = sess.execute(sqlalchemy.text(self._query))
|
result: Result = sess.execute(sqlalchemy.text(self._query))
|
||||||
except SQLAlchemyError as err:
|
except SQLAlchemyError as err:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"Error executing query %s: %s",
|
"Error executing query %s: %s",
|
||||||
@ -188,10 +188,8 @@ class SQLSensor(SensorEntity):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
_LOGGER.debug("Result %s, ResultMapping %s", result, result.mappings())
|
|
||||||
|
|
||||||
for res in result.mappings():
|
for res in result.mappings():
|
||||||
_LOGGER.debug("result = %s", res.items())
|
_LOGGER.debug("Query %s result in %s", self._query, res.items())
|
||||||
data = res[self._column_name]
|
data = res[self._column_name]
|
||||||
for key, value in res.items():
|
for key, value in res.items():
|
||||||
if isinstance(value, decimal.Decimal):
|
if isinstance(value, decimal.Decimal):
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"db_url_invalid": "Database URL invalid",
|
"db_url_invalid": "Database URL invalid",
|
||||||
"query_invalid": "SQL Query invalid",
|
"query_invalid": "SQL Query invalid"
|
||||||
"value_template_invalid": "Value Template invalid"
|
|
||||||
},
|
},
|
||||||
"step": {
|
"step": {
|
||||||
"user": {
|
"user": {
|
||||||
@ -52,8 +51,7 @@
|
|||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
|
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
|
||||||
"query_invalid": "[%key:component::sql::config::error::query_invalid%]",
|
"query_invalid": "[%key:component::sql::config::error::query_invalid%]"
|
||||||
"value_template_invalid": "[%key:component::sql::config::error::value_template_invalid%]"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,8 +5,7 @@
|
|||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"db_url_invalid": "Database URL invalid",
|
"db_url_invalid": "Database URL invalid",
|
||||||
"query_invalid": "SQL Query invalid",
|
"query_invalid": "SQL Query invalid"
|
||||||
"value_template_invalid": "Value Template invalid"
|
|
||||||
},
|
},
|
||||||
"step": {
|
"step": {
|
||||||
"user": {
|
"user": {
|
||||||
@ -32,8 +31,7 @@
|
|||||||
"options": {
|
"options": {
|
||||||
"error": {
|
"error": {
|
||||||
"db_url_invalid": "Database URL invalid",
|
"db_url_invalid": "Database URL invalid",
|
||||||
"query_invalid": "SQL Query invalid",
|
"query_invalid": "SQL Query invalid"
|
||||||
"value_template_invalid": "Value Template invalid"
|
|
||||||
},
|
},
|
||||||
"step": {
|
"step": {
|
||||||
"init": {
|
"init": {
|
||||||
|
@ -42,7 +42,6 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
ENTRY_CONFIG,
|
ENTRY_CONFIG,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
print(ENTRY_CONFIG)
|
|
||||||
|
|
||||||
assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
|
assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
|
||||||
assert result2["title"] == "Get Value"
|
assert result2["title"] == "Get Value"
|
||||||
|
@ -115,7 +115,30 @@ async def test_query_no_value(
|
|||||||
state = hass.states.get("sensor.count_tables")
|
state = hass.states.get("sensor.count_tables")
|
||||||
assert state.state == STATE_UNKNOWN
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
text = "SELECT 5 as value where 1=2 returned no results"
|
text = "SELECT 5 as value where 1=2 LIMIT 1; returned no results"
|
||||||
|
assert text in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_mssql_no_result(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test the SQL sensor with a query that returns no value."""
|
||||||
|
config = {
|
||||||
|
"db_url": "mssql://",
|
||||||
|
"query": "SELECT 5 as value where 1=2",
|
||||||
|
"column": "value",
|
||||||
|
"name": "count_tables",
|
||||||
|
}
|
||||||
|
with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch(
|
||||||
|
"homeassistant.components.sql.sensor.sqlalchemy.text",
|
||||||
|
return_value="SELECT TOP 1 5 as value where 1=2",
|
||||||
|
):
|
||||||
|
await init_integration(hass, config)
|
||||||
|
|
||||||
|
state = hass.states.get("sensor.count_tables")
|
||||||
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
|
text = "SELECT TOP 1 5 AS VALUE WHERE 1=2 returned no results"
|
||||||
assert text in caplog.text
|
assert text in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user