mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Default to recorder db for SQL integration (#85436)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
2f4e9c8ef3
commit
afa58b80bd
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.recorder import CONF_DB_URL
|
from homeassistant.components.recorder import CONF_DB_URL, get_instance
|
||||||
from homeassistant.components.sensor import (
|
from homeassistant.components.sensor import (
|
||||||
CONF_STATE_CLASS,
|
CONF_STATE_CLASS,
|
||||||
DEVICE_CLASSES_SCHEMA,
|
DEVICE_CLASSES_SCHEMA,
|
||||||
@ -53,6 +53,18 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_configured_db_url_if_not_needed(
|
||||||
|
hass: HomeAssistant, entry: ConfigEntry
|
||||||
|
) -> None:
|
||||||
|
"""Remove db url from config if it matches recorder database."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
entry,
|
||||||
|
options={
|
||||||
|
key: value for key, value in entry.options.items() if key != CONF_DB_URL
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||||
"""Update listener for options."""
|
"""Update listener for options."""
|
||||||
await hass.config_entries.async_reload(entry.entry_id)
|
await hass.config_entries.async_reload(entry.entry_id)
|
||||||
@ -73,6 +85,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Set up SQL from a config entry."""
|
"""Set up SQL from a config entry."""
|
||||||
|
if entry.options.get(CONF_DB_URL) == get_instance(hass).db_url:
|
||||||
|
remove_configured_db_url_if_not_needed(hass, entry)
|
||||||
|
|
||||||
entry.async_on_unload(entry.add_update_listener(async_update_listener))
|
entry.async_on_unload(entry.add_update_listener(async_update_listener))
|
||||||
|
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
|
@ -11,13 +11,14 @@ from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
from homeassistant.components.recorder import CONF_DB_URL
|
||||||
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
|
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers import selector
|
from homeassistant.helpers import selector
|
||||||
|
|
||||||
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||||
|
from .util import resolve_db_url
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -85,34 +86,37 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Handle the user step."""
|
"""Handle the user step."""
|
||||||
errors = {}
|
errors = {}
|
||||||
db_url_default = DEFAULT_URL.format(
|
|
||||||
hass_config_path=self.hass.config.path(DEFAULT_DB_FILE)
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
db_url = user_input.get(CONF_DB_URL, db_url_default)
|
db_url = user_input.get(CONF_DB_URL)
|
||||||
query = user_input[CONF_QUERY]
|
query = user_input[CONF_QUERY]
|
||||||
column = user_input[CONF_COLUMN_NAME]
|
column = user_input[CONF_COLUMN_NAME]
|
||||||
uom = user_input.get(CONF_UNIT_OF_MEASUREMENT)
|
uom = user_input.get(CONF_UNIT_OF_MEASUREMENT)
|
||||||
value_template = user_input.get(CONF_VALUE_TEMPLATE)
|
value_template = user_input.get(CONF_VALUE_TEMPLATE)
|
||||||
name = user_input[CONF_NAME]
|
name = user_input[CONF_NAME]
|
||||||
|
db_url_for_validation = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
validate_sql_select(query)
|
validate_sql_select(query)
|
||||||
|
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||||
await self.hass.async_add_executor_job(
|
await self.hass.async_add_executor_job(
|
||||||
validate_query, db_url, query, column
|
validate_query, db_url_for_validation, query, column
|
||||||
)
|
)
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
errors["db_url"] = "db_url_invalid"
|
errors["db_url"] = "db_url_invalid"
|
||||||
except ValueError:
|
except ValueError:
|
||||||
errors["query"] = "query_invalid"
|
errors["query"] = "query_invalid"
|
||||||
|
|
||||||
|
add_db_url = (
|
||||||
|
{CONF_DB_URL: db_url} if db_url == db_url_for_validation else {}
|
||||||
|
)
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=name,
|
title=name,
|
||||||
data={},
|
data={},
|
||||||
options={
|
options={
|
||||||
CONF_DB_URL: db_url,
|
**add_db_url,
|
||||||
CONF_QUERY: query,
|
CONF_QUERY: query,
|
||||||
CONF_COLUMN_NAME: column,
|
CONF_COLUMN_NAME: column,
|
||||||
CONF_UNIT_OF_MEASUREMENT: uom,
|
CONF_UNIT_OF_MEASUREMENT: uom,
|
||||||
@ -140,32 +144,32 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow):
|
|||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Manage SQL options."""
|
"""Manage SQL options."""
|
||||||
errors = {}
|
errors = {}
|
||||||
db_url_default = DEFAULT_URL.format(
|
|
||||||
hass_config_path=self.hass.config.path(DEFAULT_DB_FILE)
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
db_url = user_input.get(CONF_DB_URL, db_url_default)
|
db_url = user_input.get(CONF_DB_URL)
|
||||||
query = user_input[CONF_QUERY]
|
query = user_input[CONF_QUERY]
|
||||||
column = user_input[CONF_COLUMN_NAME]
|
column = user_input[CONF_COLUMN_NAME]
|
||||||
name = self.entry.options.get(CONF_NAME, self.entry.title)
|
name = self.entry.options.get(CONF_NAME, self.entry.title)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
validate_sql_select(query)
|
validate_sql_select(query)
|
||||||
|
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||||
await self.hass.async_add_executor_job(
|
await self.hass.async_add_executor_job(
|
||||||
validate_query, db_url, query, column
|
validate_query, db_url_for_validation, query, column
|
||||||
)
|
)
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
errors["db_url"] = "db_url_invalid"
|
errors["db_url"] = "db_url_invalid"
|
||||||
except ValueError:
|
except ValueError:
|
||||||
errors["query"] = "query_invalid"
|
errors["query"] = "query_invalid"
|
||||||
else:
|
else:
|
||||||
|
new_user_input = user_input
|
||||||
|
if new_user_input.get(CONF_DB_URL) and db_url == db_url_for_validation:
|
||||||
|
new_user_input.pop(CONF_DB_URL)
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title="",
|
title="",
|
||||||
data={
|
data={
|
||||||
CONF_NAME: name,
|
CONF_NAME: name,
|
||||||
CONF_DB_URL: db_url,
|
**new_user_input,
|
||||||
**user_input,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -176,7 +180,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow):
|
|||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_DB_URL,
|
CONF_DB_URL,
|
||||||
description={
|
description={
|
||||||
"suggested_value": self.entry.options[CONF_DB_URL]
|
"suggested_value": self.entry.options.get(CONF_DB_URL)
|
||||||
},
|
},
|
||||||
): selector.TextSelector(),
|
): selector.TextSelector(),
|
||||||
vol.Required(
|
vol.Required(
|
||||||
|
@ -10,7 +10,7 @@ from sqlalchemy.engine import Result
|
|||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||||
|
|
||||||
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
from homeassistant.components.recorder import CONF_DB_URL
|
||||||
from homeassistant.components.sensor import (
|
from homeassistant.components.sensor import (
|
||||||
CONF_STATE_CLASS,
|
CONF_STATE_CLASS,
|
||||||
SensorDeviceClass,
|
SensorDeviceClass,
|
||||||
@ -34,6 +34,7 @@ from homeassistant.helpers.template import Template
|
|||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from .const import CONF_COLUMN_NAME, CONF_QUERY, DB_URL_RE, DOMAIN
|
from .const import CONF_COLUMN_NAME, CONF_QUERY, DB_URL_RE, DOMAIN
|
||||||
|
from .util import resolve_db_url
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -59,7 +60,7 @@ async def async_setup_platform(
|
|||||||
value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE)
|
value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE)
|
||||||
column_name: str = conf[CONF_COLUMN_NAME]
|
column_name: str = conf[CONF_COLUMN_NAME]
|
||||||
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
|
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
|
||||||
db_url: str | None = conf.get(CONF_DB_URL)
|
db_url: str = resolve_db_url(hass, conf.get(CONF_DB_URL))
|
||||||
device_class: SensorDeviceClass | None = conf.get(CONF_DEVICE_CLASS)
|
device_class: SensorDeviceClass | None = conf.get(CONF_DEVICE_CLASS)
|
||||||
state_class: SensorStateClass | None = conf.get(CONF_STATE_CLASS)
|
state_class: SensorStateClass | None = conf.get(CONF_STATE_CLASS)
|
||||||
|
|
||||||
@ -87,7 +88,7 @@ async def async_setup_entry(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the SQL sensor from config entry."""
|
"""Set up the SQL sensor from config entry."""
|
||||||
|
|
||||||
db_url: str = entry.options[CONF_DB_URL]
|
db_url: str = resolve_db_url(hass, entry.options.get(CONF_DB_URL))
|
||||||
name: str = entry.options[CONF_NAME]
|
name: str = entry.options[CONF_NAME]
|
||||||
query_str: str = entry.options[CONF_QUERY]
|
query_str: str = entry.options[CONF_QUERY]
|
||||||
unit: str | None = entry.options.get(CONF_UNIT_OF_MEASUREMENT)
|
unit: str | None = entry.options.get(CONF_UNIT_OF_MEASUREMENT)
|
||||||
@ -128,7 +129,7 @@ async def async_setup_sensor(
|
|||||||
unit: str | None,
|
unit: str | None,
|
||||||
value_template: Template | None,
|
value_template: Template | None,
|
||||||
unique_id: str | None,
|
unique_id: str | None,
|
||||||
db_url: str | None,
|
db_url: str,
|
||||||
yaml: bool,
|
yaml: bool,
|
||||||
device_class: SensorDeviceClass | None,
|
device_class: SensorDeviceClass | None,
|
||||||
state_class: SensorStateClass | None,
|
state_class: SensorStateClass | None,
|
||||||
@ -136,16 +137,12 @@ async def async_setup_sensor(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the SQL sensor."""
|
"""Set up the SQL sensor."""
|
||||||
|
|
||||||
if not db_url:
|
|
||||||
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
|
|
||||||
|
|
||||||
sess: Session | None = None
|
|
||||||
try:
|
try:
|
||||||
engine = sqlalchemy.create_engine(db_url, future=True)
|
engine = sqlalchemy.create_engine(db_url, future=True)
|
||||||
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
||||||
|
|
||||||
# Run a dummy query just to test the db_url
|
# Run a dummy query just to test the db_url
|
||||||
sess = sessmaker()
|
sess: Session = sessmaker()
|
||||||
sess.execute(sqlalchemy.text("SELECT 1;"))
|
sess.execute(sqlalchemy.text("SELECT 1;"))
|
||||||
|
|
||||||
except SQLAlchemyError as err:
|
except SQLAlchemyError as err:
|
||||||
|
12
homeassistant/components/sql/util.py
Normal file
12
homeassistant/components/sql/util.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
"""Utils for sql."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from homeassistant.components.recorder import get_instance
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str:
|
||||||
|
"""Return the db_url provided if not empty, otherwise return the recorder db_url."""
|
||||||
|
if db_url and not db_url.isspace():
|
||||||
|
return db_url
|
||||||
|
return get_instance(hass).db_url
|
@ -23,7 +23,6 @@ from homeassistant.core import HomeAssistant
|
|||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
ENTRY_CONFIG = {
|
ENTRY_CONFIG = {
|
||||||
CONF_DB_URL: "sqlite://",
|
|
||||||
CONF_NAME: "Get Value",
|
CONF_NAME: "Get Value",
|
||||||
CONF_QUERY: "SELECT 5 as value",
|
CONF_QUERY: "SELECT 5 as value",
|
||||||
CONF_COLUMN_NAME: "value",
|
CONF_COLUMN_NAME: "value",
|
||||||
@ -31,7 +30,6 @@ ENTRY_CONFIG = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ENTRY_CONFIG_INVALID_QUERY = {
|
ENTRY_CONFIG_INVALID_QUERY = {
|
||||||
CONF_DB_URL: "sqlite://",
|
|
||||||
CONF_NAME: "Get Value",
|
CONF_NAME: "Get Value",
|
||||||
CONF_QUERY: "UPDATE 5 as value",
|
CONF_QUERY: "UPDATE 5 as value",
|
||||||
CONF_COLUMN_NAME: "size",
|
CONF_COLUMN_NAME: "size",
|
||||||
@ -39,14 +37,12 @@ ENTRY_CONFIG_INVALID_QUERY = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ENTRY_CONFIG_INVALID_QUERY_OPT = {
|
ENTRY_CONFIG_INVALID_QUERY_OPT = {
|
||||||
CONF_DB_URL: "sqlite://",
|
|
||||||
CONF_QUERY: "UPDATE 5 as value",
|
CONF_QUERY: "UPDATE 5 as value",
|
||||||
CONF_COLUMN_NAME: "size",
|
CONF_COLUMN_NAME: "size",
|
||||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY_CONFIG_NO_RESULTS = {
|
ENTRY_CONFIG_NO_RESULTS = {
|
||||||
CONF_DB_URL: "sqlite://",
|
|
||||||
CONF_NAME: "Get Value",
|
CONF_NAME: "Get Value",
|
||||||
CONF_QUERY: "SELECT kalle as value from no_table;",
|
CONF_QUERY: "SELECT kalle as value from no_table;",
|
||||||
CONF_COLUMN_NAME: "value",
|
CONF_COLUMN_NAME: "value",
|
||||||
@ -69,7 +65,6 @@ YAML_CONFIG = {
|
|||||||
|
|
||||||
YAML_CONFIG_INVALID = {
|
YAML_CONFIG_INVALID = {
|
||||||
"sql": {
|
"sql": {
|
||||||
CONF_DB_URL: "sqlite://",
|
|
||||||
CONF_QUERY: "SELECT 5 as value",
|
CONF_QUERY: "SELECT 5 as value",
|
||||||
CONF_COLUMN_NAME: "value",
|
CONF_COLUMN_NAME: "value",
|
||||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||||
|
@ -6,7 +6,7 @@ from unittest.mock import patch
|
|||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.recorder import DEFAULT_DB_FILE, DEFAULT_URL, Recorder
|
from homeassistant.components.recorder import Recorder
|
||||||
from homeassistant.components.sql.const import DOMAIN
|
from homeassistant.components.sql.const import DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
@ -43,7 +43,6 @@ async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None:
|
|||||||
assert result2["type"] == FlowResultType.CREATE_ENTRY
|
assert result2["type"] == FlowResultType.CREATE_ENTRY
|
||||||
assert result2["title"] == "Get Value"
|
assert result2["title"] == "Get Value"
|
||||||
assert result2["options"] == {
|
assert result2["options"] == {
|
||||||
"db_url": "sqlite://",
|
|
||||||
"name": "Get Value",
|
"name": "Get Value",
|
||||||
"query": "SELECT 5 as value",
|
"query": "SELECT 5 as value",
|
||||||
"column": "value",
|
"column": "value",
|
||||||
@ -113,7 +112,6 @@ async def test_flow_fails_invalid_query(
|
|||||||
assert result5["type"] == FlowResultType.CREATE_ENTRY
|
assert result5["type"] == FlowResultType.CREATE_ENTRY
|
||||||
assert result5["title"] == "Get Value"
|
assert result5["title"] == "Get Value"
|
||||||
assert result5["options"] == {
|
assert result5["options"] == {
|
||||||
"db_url": "sqlite://",
|
|
||||||
"name": "Get Value",
|
"name": "Get Value",
|
||||||
"query": "SELECT 5 as value",
|
"query": "SELECT 5 as value",
|
||||||
"column": "value",
|
"column": "value",
|
||||||
@ -163,7 +161,6 @@ async def test_options_flow(recorder_mock: Recorder, hass: HomeAssistant) -> Non
|
|||||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
assert result["data"] == {
|
assert result["data"] == {
|
||||||
"name": "Get Value",
|
"name": "Get Value",
|
||||||
"db_url": "sqlite://",
|
|
||||||
"query": "SELECT 5 as size",
|
"query": "SELECT 5 as size",
|
||||||
"column": "size",
|
"column": "size",
|
||||||
"unit_of_measurement": "MiB",
|
"unit_of_measurement": "MiB",
|
||||||
@ -215,7 +212,6 @@ async def test_options_flow_name_previously_removed(
|
|||||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
assert result["data"] == {
|
assert result["data"] == {
|
||||||
"name": "Get Value Title",
|
"name": "Get Value Title",
|
||||||
"db_url": "sqlite://",
|
|
||||||
"query": "SELECT 5 as size",
|
"query": "SELECT 5 as size",
|
||||||
"column": "size",
|
"column": "size",
|
||||||
"unit_of_measurement": "MiB",
|
"unit_of_measurement": "MiB",
|
||||||
@ -316,7 +312,6 @@ async def test_options_flow_fails_invalid_query(
|
|||||||
assert result4["type"] == FlowResultType.CREATE_ENTRY
|
assert result4["type"] == FlowResultType.CREATE_ENTRY
|
||||||
assert result4["data"] == {
|
assert result4["data"] == {
|
||||||
"name": "Get Value",
|
"name": "Get Value",
|
||||||
"db_url": "sqlite://",
|
|
||||||
"query": "SELECT 5 as size",
|
"query": "SELECT 5 as size",
|
||||||
"column": "size",
|
"column": "size",
|
||||||
"unit_of_measurement": "MiB",
|
"unit_of_measurement": "MiB",
|
||||||
@ -369,12 +364,9 @@ async def test_options_flow_db_url_empty(
|
|||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
|
|
||||||
|
|
||||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
assert result["data"] == {
|
assert result["data"] == {
|
||||||
"name": "Get Value",
|
"name": "Get Value",
|
||||||
"db_url": db_url,
|
|
||||||
"query": "SELECT 5 as size",
|
"query": "SELECT 5 as size",
|
||||||
"column": "size",
|
"column": "size",
|
||||||
"unit_of_measurement": "MiB",
|
"unit_of_measurement": "MiB",
|
||||||
|
@ -8,6 +8,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.recorder import Recorder
|
from homeassistant.components.recorder import Recorder
|
||||||
|
from homeassistant.components.recorder.util import get_instance
|
||||||
from homeassistant.components.sql import validate_sql_select
|
from homeassistant.components.sql import validate_sql_select
|
||||||
from homeassistant.components.sql.const import DOMAIN
|
from homeassistant.components.sql.const import DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -56,3 +57,41 @@ async def test_invalid_query(hass: HomeAssistant) -> None:
|
|||||||
"""Test invalid query."""
|
"""Test invalid query."""
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
validate_sql_select("DROP TABLE *")
|
validate_sql_select("DROP TABLE *")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remove_configured_db_url_if_not_needed_when_not_needed(
|
||||||
|
recorder_mock: Recorder,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test configured db_url is replaced with None if matching the recorder db."""
|
||||||
|
recorder_db_url = get_instance(hass).db_url
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"db_url": recorder_db_url,
|
||||||
|
"query": "SELECT 5 as value",
|
||||||
|
"column": "value",
|
||||||
|
"name": "count_tables",
|
||||||
|
}
|
||||||
|
|
||||||
|
config_entry = await init_integration(hass, config)
|
||||||
|
|
||||||
|
assert config_entry.options.get("db_url") is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remove_configured_db_url_if_not_needed_when_needed(
|
||||||
|
recorder_mock: Recorder,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test configured db_url is not replaced if it differs from the recorder db."""
|
||||||
|
db_url = "mssql://"
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"db_url": db_url,
|
||||||
|
"query": "SELECT 5 as value",
|
||||||
|
"column": "value",
|
||||||
|
"name": "count_tables",
|
||||||
|
}
|
||||||
|
|
||||||
|
config_entry = await init_integration(hass, config)
|
||||||
|
|
||||||
|
assert config_entry.options.get("db_url") == db_url
|
||||||
|
@ -182,6 +182,7 @@ async def test_invalid_url_setup(
|
|||||||
|
|
||||||
|
|
||||||
async def test_invalid_url_on_update(
|
async def test_invalid_url_on_update(
|
||||||
|
recorder_mock: Recorder,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -192,22 +193,9 @@ async def test_invalid_url_on_update(
|
|||||||
"column": "value",
|
"column": "value",
|
||||||
"name": "count_tables",
|
"name": "count_tables",
|
||||||
}
|
}
|
||||||
entry = MockConfigEntry(
|
await init_integration(hass, config)
|
||||||
domain=DOMAIN,
|
|
||||||
source=SOURCE_USER,
|
|
||||||
data={},
|
|
||||||
options=config,
|
|
||||||
entry_id="1",
|
|
||||||
)
|
|
||||||
|
|
||||||
entry.add_to_hass(hass)
|
|
||||||
|
|
||||||
await hass.config_entries.async_setup(entry.entry_id)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.recorder",
|
|
||||||
), patch(
|
|
||||||
"homeassistant.components.sql.sensor.sqlalchemy.engine.cursor.CursorResult",
|
"homeassistant.components.sql.sensor.sqlalchemy.engine.cursor.CursorResult",
|
||||||
side_effect=SQLAlchemyError(
|
side_effect=SQLAlchemyError(
|
||||||
"sqlite://homeassistant:hunter2@homeassistant.local"
|
"sqlite://homeassistant:hunter2@homeassistant.local"
|
||||||
@ -219,7 +207,6 @@ async def test_invalid_url_on_update(
|
|||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert "sqlite://homeassistant:hunter2@homeassistant.local" not in caplog.text
|
|
||||||
assert "sqlite://****:****@homeassistant.local" in caplog.text
|
assert "sqlite://****:****@homeassistant.local" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
25
tests/components/sql/test_util.py
Normal file
25
tests/components/sql/test_util.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
"""Test the sql utils."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from homeassistant.components.recorder import get_instance
|
||||||
|
from homeassistant.components.sql.util import resolve_db_url
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_db_url_when_none_configured(
|
||||||
|
recorder_mock: AsyncMock,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Test return recorder db_url if provided db_url is None."""
|
||||||
|
db_url = None
|
||||||
|
resolved_url = resolve_db_url(hass, db_url)
|
||||||
|
|
||||||
|
assert resolved_url == get_instance(hass).db_url
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resolve_db_url_when_configured(hass: HomeAssistant):
|
||||||
|
"""Test return provided db_url if it's set."""
|
||||||
|
db_url = "mssql://"
|
||||||
|
resolved_url = resolve_db_url(hass, db_url)
|
||||||
|
|
||||||
|
assert resolved_url == db_url
|
Loading…
x
Reference in New Issue
Block a user