From 8da150bd71e9e134d908a9e575abb251d1a4ad26 Mon Sep 17 00:00:00 2001 From: G Johansson Date: Sat, 12 Feb 2022 15:13:01 +0100 Subject: [PATCH] Improve code quality sql (#65321) --- homeassistant/components/sql/sensor.py | 110 ++++++++++------------ tests/components/sql/test_sensor.py | 121 ++++++++++++++++++++++++- 2 files changed, 167 insertions(+), 64 deletions(-) diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index b9e3b9ce81d..1c8e87051be 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -1,8 +1,7 @@ """Sensor from an SQL Query.""" from __future__ import annotations -import datetime -import decimal +from datetime import date import logging import re @@ -11,11 +10,15 @@ from sqlalchemy.orm import scoped_session, sessionmaker import voluptuous as vol from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL -from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity +from homeassistant.components.sensor import ( + PLATFORM_SCHEMA as PARENT_PLATFORM_SCHEMA, + SensorEntity, +) from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType _LOGGER = logging.getLogger(__name__) @@ -27,12 +30,12 @@ CONF_QUERY = "query" DB_URL_RE = re.compile("//.*:.*@") -def redact_credentials(data): +def redact_credentials(data: str) -> str: """Redact credentials from string data.""" return DB_URL_RE.sub("//****:****@", data) -def validate_sql_select(value): +def validate_sql_select(value: str) -> str: """Validate that value is a SQL SELECT query.""" if not value.lstrip().lower().startswith("select"): raise vol.Invalid("Only SELECT queries allowed") @@ -49,7 +52,7 @@ _QUERY_SCHEME = vol.Schema( } ) -PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( +PLATFORM_SCHEMA = PARENT_PLATFORM_SCHEMA.extend( {vol.Required(CONF_QUERIES): [_QUERY_SCHEME], vol.Optional(CONF_DB_URL): cv.string} ) @@ -64,7 +67,7 @@ def setup_platform( if not (db_url := config.get(CONF_DB_URL)): db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE)) - sess = None + sess: scoped_session | None = None try: engine = sqlalchemy.create_engine(db_url) sessmaker = scoped_session(sessionmaker(bind=engine)) @@ -87,11 +90,11 @@ def setup_platform( queries = [] for query in config[CONF_QUERIES]: - name = query.get(CONF_NAME) - query_str = query.get(CONF_QUERY) - unit = query.get(CONF_UNIT_OF_MEASUREMENT) - value_template = query.get(CONF_VALUE_TEMPLATE) - column_name = query.get(CONF_COLUMN_NAME) + name: str = query[CONF_NAME] + query_str: str = query[CONF_QUERY] + unit: str | None = query.get(CONF_UNIT_OF_MEASUREMENT) + value_template: Template | None = query.get(CONF_VALUE_TEMPLATE) + column_name: str = query[CONF_COLUMN_NAME] if value_template is not None: value_template.hass = hass @@ -115,60 +118,32 @@ def setup_platform( class SQLSensor(SensorEntity): """Representation of an SQL sensor.""" - def __init__(self, name, sessmaker, query, column, unit, value_template): + def __init__( + self, + name: str, + sessmaker: scoped_session, + query: str, + column: str, + unit: str | None, + value_template: Template | None, + ) -> None: """Initialize the SQL sensor.""" - self._name = name + self._attr_name = name self._query = query - self._unit_of_measurement = unit + self._attr_native_unit_of_measurement = unit self._template = value_template self._column_name = column self.sessionmaker = sessmaker - self._state = None - self._attributes = None + self._attr_extra_state_attributes = {} - @property - def name(self): - """Return the name of the query.""" - return self._name - - @property - def native_value(self): - """Return the query's current state.""" - return self._state - - @property - def native_unit_of_measurement(self): - """Return the unit of measurement.""" - return self._unit_of_measurement - - @property - def extra_state_attributes(self): - """Return the state attributes.""" - return self._attributes - - def update(self): + def update(self) -> None: """Retrieve sensor data from the query.""" data = None + self._attr_extra_state_attributes = {} + sess: scoped_session = self.sessionmaker() try: - sess = self.sessionmaker() result = sess.execute(self._query) - self._attributes = {} - - if not result.returns_rows or result.rowcount == 0: - _LOGGER.warning("%s returned no results", self._query) - self._state = None - return - - for res in result.mappings(): - _LOGGER.debug("result = %s", res.items()) - data = res[self._column_name] - for key, value in res.items(): - if isinstance(value, decimal.Decimal): - value = float(value) - if isinstance(value, datetime.date): - value = str(value) - self._attributes[key] = value except sqlalchemy.exc.SQLAlchemyError as err: _LOGGER.error( "Error executing query %s: %s", @@ -176,12 +151,27 @@ class SQLSensor(SensorEntity): redact_credentials(str(err)), ) return - finally: - sess.close() + + _LOGGER.debug("Result %s, ResultMapping %s", result, result.mappings()) + + for res in result.mappings(): + _LOGGER.debug("result = %s", res.items()) + data = res[self._column_name] + for key, value in res.items(): + if isinstance(value, float): + value = float(value) + if isinstance(value, date): + value = value.isoformat() + self._attr_extra_state_attributes[key] = value if data is not None and self._template is not None: - self._state = self._template.async_render_with_possible_json_value( - data, None + self._attr_native_value = ( + self._template.async_render_with_possible_json_value(data, None) ) else: - self._state = data + self._attr_native_value = data + + if not data: + _LOGGER.warning("%s returned no results", self._query) + + sess.close() diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index 629ec464e58..0e543f98a21 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -1,13 +1,27 @@ """The test for the sql sensor platform.""" +import os + import pytest import voluptuous as vol from homeassistant.components.sql.sensor import validate_sql_select from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component +from tests.common import get_test_config_dir -async def test_query(hass): + +@pytest.fixture(autouse=True) +def remove_file(): + """Remove db.""" + yield + file = os.path.join(get_test_config_dir(), "home-assistant_v2.db") + if os.path.isfile(file): + os.remove(file) + + +async def test_query(hass: HomeAssistant) -> None: """Test the SQL sensor.""" config = { "sensor": { @@ -31,7 +45,53 @@ async def test_query(hass): assert state.attributes["value"] == 5 -async def test_query_limit(hass): +async def test_query_no_db(hass: HomeAssistant) -> None: + """Test the SQL sensor.""" + config = { + "sensor": { + "platform": "sql", + "queries": [ + { + "name": "count_tables", + "query": "SELECT 5 as value", + "column": "value", + } + ], + } + } + + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.count_tables") + assert state.state == "5" + + +async def test_query_value_template(hass: HomeAssistant) -> None: + """Test the SQL sensor.""" + config = { + "sensor": { + "platform": "sql", + "db_url": "sqlite://", + "queries": [ + { + "name": "count_tables", + "query": "SELECT 5.01 as value", + "column": "value", + "value_template": "{{ value | int }}", + } + ], + } + } + + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.count_tables") + assert state.state == "5" + + +async def test_query_limit(hass: HomeAssistant) -> None: """Test the SQL sensor with a query containing 'LIMIT' in lowercase.""" config = { "sensor": { @@ -55,7 +115,30 @@ async def test_query_limit(hass): assert state.attributes["value"] == 5 -async def test_invalid_query(hass): +async def test_query_no_value(hass: HomeAssistant) -> None: + """Test the SQL sensor with a query that returns no value.""" + config = { + "sensor": { + "platform": "sql", + "db_url": "sqlite://", + "queries": [ + { + "name": "count_tables", + "query": "SELECT 5 as value where 1=2", + "column": "value", + } + ], + } + } + + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.count_tables") + assert state.state == STATE_UNKNOWN + + +async def test_invalid_query(hass: HomeAssistant) -> None: """Test the SQL sensor for invalid queries.""" with pytest.raises(vol.Invalid): validate_sql_select("DROP TABLE *") @@ -81,6 +164,30 @@ async def test_invalid_query(hass): assert state.state == STATE_UNKNOWN +async def test_value_float_and_date(hass: HomeAssistant) -> None: + """Test the SQL sensor with a query has float as value.""" + config = { + "sensor": { + "platform": "sql", + "db_url": "sqlite://", + "queries": [ + { + "name": "float_value", + "query": "SELECT 5 as value, cast(5.01 as decimal(10,2)) as value2", + "column": "value", + }, + ], + } + } + + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.float_value") + assert state.state == "5" + assert isinstance(state.attributes["value2"], float) + + @pytest.mark.parametrize( "url,expected_patterns,not_expected_patterns", [ @@ -96,7 +203,13 @@ async def test_invalid_query(hass): ), ], ) -async def test_invalid_url(hass, caplog, url, expected_patterns, not_expected_patterns): +async def test_invalid_url( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + url: str, + expected_patterns: str, + not_expected_patterns: str, +): """Test credentials in url is not logged.""" config = { "sensor": {