mirror of
https://github.com/home-assistant/core.git
synced 2025-07-26 06:37:52 +00:00
* Addresses issue #12856 * error -> warning * added edge case and test * uff uff * Added SELECT validation * Improved tests
This commit is contained in:
parent
38af04c6ce
commit
5063464d5e
@ -24,9 +24,17 @@ CONF_QUERIES = 'queries'
|
|||||||
CONF_QUERY = 'query'
|
CONF_QUERY = 'query'
|
||||||
CONF_COLUMN_NAME = 'column'
|
CONF_COLUMN_NAME = 'column'
|
||||||
|
|
||||||
|
|
||||||
|
def validate_sql_select(value):
|
||||||
|
"""Validate that value is a SQL SELECT query."""
|
||||||
|
if not value.lstrip().lower().startswith('select'):
|
||||||
|
raise vol.Invalid('Only SELECT queries allowed')
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
_QUERY_SCHEME = vol.Schema({
|
_QUERY_SCHEME = vol.Schema({
|
||||||
vol.Required(CONF_NAME): cv.string,
|
vol.Required(CONF_NAME): cv.string,
|
||||||
vol.Required(CONF_QUERY): cv.string,
|
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||||||
vol.Required(CONF_COLUMN_NAME): cv.string,
|
vol.Required(CONF_COLUMN_NAME): cv.string,
|
||||||
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
|
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
|
||||||
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
|
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
|
||||||
@ -129,15 +137,17 @@ class SQLSensor(Entity):
|
|||||||
finally:
|
finally:
|
||||||
sess.close()
|
sess.close()
|
||||||
|
|
||||||
for res in result:
|
if not result.returns_rows or result.rowcount == 0:
|
||||||
_LOGGER.debug(res.items())
|
_LOGGER.warning("%s returned no results", self._query)
|
||||||
data = res[self._column_name]
|
self._state = None
|
||||||
self._attributes = {k: str(v) for k, v in res.items()}
|
self._attributes = {}
|
||||||
|
|
||||||
if data is None:
|
|
||||||
_LOGGER.error("%s returned no results", self._query)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
for res in result:
|
||||||
|
_LOGGER.debug("result = %s", res.items())
|
||||||
|
data = res[self._column_name]
|
||||||
|
self._attributes = {k: v for k, v in res.items()}
|
||||||
|
|
||||||
if self._template is not None:
|
if self._template is not None:
|
||||||
self._state = self._template.async_render_with_possible_json_value(
|
self._state = self._template.async_render_with_possible_json_value(
|
||||||
data, None)
|
data, None)
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
"""The test for the sql sensor platform."""
|
"""The test for the sql sensor platform."""
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.sensor.sql import validate_sql_select
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
|
|
||||||
from tests.common import get_test_home_assistant
|
from tests.common import get_test_home_assistant
|
||||||
|
|
||||||
@ -35,3 +39,25 @@ class TestSQLSensor(unittest.TestCase):
|
|||||||
|
|
||||||
state = self.hass.states.get('sensor.count_tables')
|
state = self.hass.states.get('sensor.count_tables')
|
||||||
self.assertEqual(state.state, '0')
|
self.assertEqual(state.state, '0')
|
||||||
|
|
||||||
|
def test_invalid_query(self):
|
||||||
|
"""Test the SQL sensor for invalid queries."""
|
||||||
|
with pytest.raises(vol.Invalid):
|
||||||
|
validate_sql_select("DROP TABLE *")
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'sensor': {
|
||||||
|
'platform': 'sql',
|
||||||
|
'db_url': 'sqlite://',
|
||||||
|
'queries': [{
|
||||||
|
'name': 'count_tables',
|
||||||
|
'query': 'SELECT * value FROM sqlite_master;',
|
||||||
|
'column': 'value',
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert setup_component(self.hass, 'sensor', config)
|
||||||
|
|
||||||
|
state = self.hass.states.get('sensor.count_tables')
|
||||||
|
self.assertEqual(state.state, STATE_UNKNOWN)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user