Improve code quality in sql integration (#71705)

This commit is contained in:
G Johansson 2022-05-13 01:40:00 +02:00 committed by GitHub
parent ae89a1243a
commit a746d7c1d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 34 deletions

View File

@ -13,7 +13,7 @@ import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
from homeassistant.core import HomeAssistant, callback from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import selector from homeassistant.helpers import selector
@ -44,24 +44,23 @@ def validate_sql_select(value: str) -> str | None:
def validate_query(db_url: str, query: str, column: str) -> bool: def validate_query(db_url: str, query: str, column: str) -> bool:
"""Validate SQL query.""" """Validate SQL query."""
try:
engine = sqlalchemy.create_engine(db_url, future=True) engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True)) sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
except SQLAlchemyError as error:
raise error
sess: scoped_session = sessmaker() sess: scoped_session = sessmaker()
try: try:
result: Result = sess.execute(sqlalchemy.text(query)) result: Result = sess.execute(sqlalchemy.text(query))
for res in result.mappings():
data = res[column]
_LOGGER.debug("Return value from query: %s", data)
except SQLAlchemyError as error: except SQLAlchemyError as error:
_LOGGER.debug("Execution error %s", error)
if sess: if sess:
sess.close() sess.close()
raise ValueError(error) from error raise ValueError(error) from error
for res in result.mappings():
data = res[column]
_LOGGER.debug("Return value from query: %s", data)
if sess: if sess:
sess.close() sess.close()
@ -73,9 +72,6 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
entry: config_entries.ConfigEntry
hass: HomeAssistant
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(

View File

@ -6,6 +6,7 @@ import decimal
import logging import logging
import sqlalchemy import sqlalchemy
from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
import voluptuous as vol import voluptuous as vol
@ -73,11 +74,11 @@ async def async_setup_platform(
for query in config[CONF_QUERIES]: for query in config[CONF_QUERIES]:
new_config = { new_config = {
CONF_DB_URL: config.get(CONF_DB_URL, default_db_url), CONF_DB_URL: config.get(CONF_DB_URL, default_db_url),
CONF_NAME: query.get(CONF_NAME), CONF_NAME: query[CONF_NAME],
CONF_QUERY: query.get(CONF_QUERY), CONF_QUERY: query[CONF_QUERY],
CONF_UNIT_OF_MEASUREMENT: query.get(CONF_UNIT_OF_MEASUREMENT), CONF_UNIT_OF_MEASUREMENT: query.get(CONF_UNIT_OF_MEASUREMENT),
CONF_VALUE_TEMPLATE: query.get(CONF_VALUE_TEMPLATE), CONF_VALUE_TEMPLATE: query.get(CONF_VALUE_TEMPLATE),
CONF_COLUMN_NAME: query.get(CONF_COLUMN_NAME), CONF_COLUMN_NAME: query[CONF_COLUMN_NAME],
} }
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
@ -119,11 +120,10 @@ async def async_setup_entry(
# MSSQL uses TOP and not LIMIT # MSSQL uses TOP and not LIMIT
if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()): if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()):
query_str = ( if "mssql" in db_url:
query_str.replace("SELECT", "SELECT TOP 1") query_str = query_str.upper().replace("SELECT", "SELECT TOP 1")
if "mssql" in db_url else:
else query_str.replace(";", " LIMIT 1;") query_str = query_str.replace(";", "") + " LIMIT 1;"
)
async_add_entities( async_add_entities(
[ [
@ -179,7 +179,7 @@ class SQLSensor(SensorEntity):
self._attr_extra_state_attributes = {} self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker() sess: scoped_session = self.sessionmaker()
try: try:
result = sess.execute(sqlalchemy.text(self._query)) result: Result = sess.execute(sqlalchemy.text(self._query))
except SQLAlchemyError as err: except SQLAlchemyError as err:
_LOGGER.error( _LOGGER.error(
"Error executing query %s: %s", "Error executing query %s: %s",
@ -188,10 +188,8 @@ class SQLSensor(SensorEntity):
) )
return return
_LOGGER.debug("Result %s, ResultMapping %s", result, result.mappings())
for res in result.mappings(): for res in result.mappings():
_LOGGER.debug("result = %s", res.items()) _LOGGER.debug("Query %s result in %s", self._query, res.items())
data = res[self._column_name] data = res[self._column_name]
for key, value in res.items(): for key, value in res.items():
if isinstance(value, decimal.Decimal): if isinstance(value, decimal.Decimal):

View File

@ -5,8 +5,7 @@
}, },
"error": { "error": {
"db_url_invalid": "Database URL invalid", "db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid", "query_invalid": "SQL Query invalid"
"value_template_invalid": "Value Template invalid"
}, },
"step": { "step": {
"user": { "user": {
@ -52,8 +51,7 @@
}, },
"error": { "error": {
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]", "db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
"query_invalid": "[%key:component::sql::config::error::query_invalid%]", "query_invalid": "[%key:component::sql::config::error::query_invalid%]"
"value_template_invalid": "[%key:component::sql::config::error::value_template_invalid%]"
} }
} }
} }

View File

@ -5,8 +5,7 @@
}, },
"error": { "error": {
"db_url_invalid": "Database URL invalid", "db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid", "query_invalid": "SQL Query invalid"
"value_template_invalid": "Value Template invalid"
}, },
"step": { "step": {
"user": { "user": {
@ -32,8 +31,7 @@
"options": { "options": {
"error": { "error": {
"db_url_invalid": "Database URL invalid", "db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid", "query_invalid": "SQL Query invalid"
"value_template_invalid": "Value Template invalid"
}, },
"step": { "step": {
"init": { "init": {

View File

@ -42,7 +42,6 @@ async def test_form(hass: HomeAssistant) -> None:
ENTRY_CONFIG, ENTRY_CONFIG,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
print(ENTRY_CONFIG)
assert result2["type"] == RESULT_TYPE_CREATE_ENTRY assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
assert result2["title"] == "Get Value" assert result2["title"] == "Get Value"

View File

@ -115,7 +115,30 @@ async def test_query_no_value(
state = hass.states.get("sensor.count_tables") state = hass.states.get("sensor.count_tables")
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
text = "SELECT 5 as value where 1=2 returned no results" text = "SELECT 5 as value where 1=2 LIMIT 1; returned no results"
assert text in caplog.text
async def test_query_mssql_no_result(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test the SQL sensor with a query that returns no value."""
config = {
"db_url": "mssql://",
"query": "SELECT 5 as value where 1=2",
"column": "value",
"name": "count_tables",
}
with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch(
"homeassistant.components.sql.sensor.sqlalchemy.text",
return_value="SELECT TOP 1 5 as value where 1=2",
):
await init_integration(hass, config)
state = hass.states.get("sensor.count_tables")
assert state.state == STATE_UNKNOWN
text = "SELECT TOP 1 5 AS VALUE WHERE 1=2 returned no results"
assert text in caplog.text assert text in caplog.text