mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Improve code quality in sql integration (#71705)
This commit is contained in:
parent
ae89a1243a
commit
a746d7c1d7
@ -13,7 +13,7 @@ import voluptuous as vol
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
||||
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
@ -44,24 +44,23 @@ def validate_sql_select(value: str) -> str | None:
|
||||
|
||||
def validate_query(db_url: str, query: str, column: str) -> bool:
|
||||
"""Validate SQL query."""
|
||||
try:
|
||||
engine = sqlalchemy.create_engine(db_url, future=True)
|
||||
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
||||
except SQLAlchemyError as error:
|
||||
raise error
|
||||
|
||||
engine = sqlalchemy.create_engine(db_url, future=True)
|
||||
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
||||
sess: scoped_session = sessmaker()
|
||||
|
||||
try:
|
||||
result: Result = sess.execute(sqlalchemy.text(query))
|
||||
for res in result.mappings():
|
||||
data = res[column]
|
||||
_LOGGER.debug("Return value from query: %s", data)
|
||||
except SQLAlchemyError as error:
|
||||
_LOGGER.debug("Execution error %s", error)
|
||||
if sess:
|
||||
sess.close()
|
||||
raise ValueError(error) from error
|
||||
|
||||
for res in result.mappings():
|
||||
data = res[column]
|
||||
_LOGGER.debug("Return value from query: %s", data)
|
||||
|
||||
if sess:
|
||||
sess.close()
|
||||
|
||||
@ -73,9 +72,6 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
entry: config_entries.ConfigEntry
|
||||
hass: HomeAssistant
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(
|
||||
|
@ -6,6 +6,7 @@ import decimal
|
||||
import logging
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine import Result
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
import voluptuous as vol
|
||||
@ -73,11 +74,11 @@ async def async_setup_platform(
|
||||
for query in config[CONF_QUERIES]:
|
||||
new_config = {
|
||||
CONF_DB_URL: config.get(CONF_DB_URL, default_db_url),
|
||||
CONF_NAME: query.get(CONF_NAME),
|
||||
CONF_QUERY: query.get(CONF_QUERY),
|
||||
CONF_NAME: query[CONF_NAME],
|
||||
CONF_QUERY: query[CONF_QUERY],
|
||||
CONF_UNIT_OF_MEASUREMENT: query.get(CONF_UNIT_OF_MEASUREMENT),
|
||||
CONF_VALUE_TEMPLATE: query.get(CONF_VALUE_TEMPLATE),
|
||||
CONF_COLUMN_NAME: query.get(CONF_COLUMN_NAME),
|
||||
CONF_COLUMN_NAME: query[CONF_COLUMN_NAME],
|
||||
}
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
@ -119,11 +120,10 @@ async def async_setup_entry(
|
||||
|
||||
# MSSQL uses TOP and not LIMIT
|
||||
if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()):
|
||||
query_str = (
|
||||
query_str.replace("SELECT", "SELECT TOP 1")
|
||||
if "mssql" in db_url
|
||||
else query_str.replace(";", " LIMIT 1;")
|
||||
)
|
||||
if "mssql" in db_url:
|
||||
query_str = query_str.upper().replace("SELECT", "SELECT TOP 1")
|
||||
else:
|
||||
query_str = query_str.replace(";", "") + " LIMIT 1;"
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
@ -179,7 +179,7 @@ class SQLSensor(SensorEntity):
|
||||
self._attr_extra_state_attributes = {}
|
||||
sess: scoped_session = self.sessionmaker()
|
||||
try:
|
||||
result = sess.execute(sqlalchemy.text(self._query))
|
||||
result: Result = sess.execute(sqlalchemy.text(self._query))
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.error(
|
||||
"Error executing query %s: %s",
|
||||
@ -188,10 +188,8 @@ class SQLSensor(SensorEntity):
|
||||
)
|
||||
return
|
||||
|
||||
_LOGGER.debug("Result %s, ResultMapping %s", result, result.mappings())
|
||||
|
||||
for res in result.mappings():
|
||||
_LOGGER.debug("result = %s", res.items())
|
||||
_LOGGER.debug("Query %s result in %s", self._query, res.items())
|
||||
data = res[self._column_name]
|
||||
for key, value in res.items():
|
||||
if isinstance(value, decimal.Decimal):
|
||||
|
@ -5,8 +5,7 @@
|
||||
},
|
||||
"error": {
|
||||
"db_url_invalid": "Database URL invalid",
|
||||
"query_invalid": "SQL Query invalid",
|
||||
"value_template_invalid": "Value Template invalid"
|
||||
"query_invalid": "SQL Query invalid"
|
||||
},
|
||||
"step": {
|
||||
"user": {
|
||||
@ -52,8 +51,7 @@
|
||||
},
|
||||
"error": {
|
||||
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
|
||||
"query_invalid": "[%key:component::sql::config::error::query_invalid%]",
|
||||
"value_template_invalid": "[%key:component::sql::config::error::value_template_invalid%]"
|
||||
"query_invalid": "[%key:component::sql::config::error::query_invalid%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,8 +5,7 @@
|
||||
},
|
||||
"error": {
|
||||
"db_url_invalid": "Database URL invalid",
|
||||
"query_invalid": "SQL Query invalid",
|
||||
"value_template_invalid": "Value Template invalid"
|
||||
"query_invalid": "SQL Query invalid"
|
||||
},
|
||||
"step": {
|
||||
"user": {
|
||||
@ -32,8 +31,7 @@
|
||||
"options": {
|
||||
"error": {
|
||||
"db_url_invalid": "Database URL invalid",
|
||||
"query_invalid": "SQL Query invalid",
|
||||
"value_template_invalid": "Value Template invalid"
|
||||
"query_invalid": "SQL Query invalid"
|
||||
},
|
||||
"step": {
|
||||
"init": {
|
||||
|
@ -42,7 +42,6 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
ENTRY_CONFIG,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
print(ENTRY_CONFIG)
|
||||
|
||||
assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
|
||||
assert result2["title"] == "Get Value"
|
||||
|
@ -115,7 +115,30 @@ async def test_query_no_value(
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
text = "SELECT 5 as value where 1=2 returned no results"
|
||||
text = "SELECT 5 as value where 1=2 LIMIT 1; returned no results"
|
||||
assert text in caplog.text
|
||||
|
||||
|
||||
async def test_query_mssql_no_result(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test the SQL sensor with a query that returns no value."""
|
||||
config = {
|
||||
"db_url": "mssql://",
|
||||
"query": "SELECT 5 as value where 1=2",
|
||||
"column": "value",
|
||||
"name": "count_tables",
|
||||
}
|
||||
with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch(
|
||||
"homeassistant.components.sql.sensor.sqlalchemy.text",
|
||||
return_value="SELECT TOP 1 5 as value where 1=2",
|
||||
):
|
||||
await init_integration(hass, config)
|
||||
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
text = "SELECT TOP 1 5 AS VALUE WHERE 1=2 returned no results"
|
||||
assert text in caplog.text
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user