Files
core/tests/components/sql/test_util.py
2025-11-02 18:49:18 +01:00

94 lines
2.5 KiB
Python

"""Test the sql utils."""
from datetime import UTC, date, datetime
from decimal import Decimal
import pytest
import voluptuous as vol
from homeassistant.components.recorder import Recorder, get_instance
from homeassistant.components.sql.util import (
convert_value,
resolve_db_url,
validate_sql_select,
)
from homeassistant.core import HomeAssistant
async def test_resolve_db_url_when_none_configured(
    recorder_mock: Recorder,
    hass: HomeAssistant,
) -> None:
    """Check that passing ``None`` as db_url yields the recorder instance's db_url."""
    # The resolver is called first, then compared against the URL exposed by
    # the recorder instance registered on ``hass``.
    assert resolve_db_url(hass, None) == get_instance(hass).db_url
async def test_resolve_db_url_when_configured(hass: HomeAssistant) -> None:
    """Check that an explicitly supplied db_url is handed back as-is."""
    configured_url = "mssql://"
    assert resolve_db_url(hass, configured_url) == configured_url
@pytest.mark.parametrize(
    ("sql_query", "expected_error_message"),
    [
        # Statements that are not SELECT.
        ("DROP TABLE *", "Only SELECT queries allowed"),
        # Input expected to be rejected as an invalid query.
        ("SELECT5 as value", "Invalid SQL query"),
        (";;", "Invalid SQL query"),
        # Data-modifying statements, with and without a leading CTE.
        (
            "UPDATE states SET state = 999999 WHERE state_id = 11125",
            "Only SELECT queries allowed",
        ),
        (
            "WITH test AS (SELECT state FROM states) "
            "UPDATE states SET states.state = test.state;",
            "Only SELECT queries allowed",
        ),
        # A SELECT followed by a second statement.
        (
            "SELECT 5 as value; UPDATE states SET state = 10;",
            "Multiple SQL queries are not supported",
        ),
    ],
)
async def test_invalid_sql_queries(
    hass: HomeAssistant,
    sql_query: str,
    expected_error_message: str,
) -> None:
    """Check that each rejected query raises ``vol.Invalid`` with the expected message.

    ``expected_error_message`` is handed to ``pytest.raises(match=...)``, so it
    is applied as a regular-expression search against the raised error's text.
    """
    with pytest.raises(vol.Invalid, match=expected_error_message):
        validate_sql_select(sql_query)
@pytest.mark.parametrize(
    ("input", "expected_output"),
    [
        # Decimal input is expected to compare equal to a float.
        (Decimal("199.99"), 199.99),
        # date / datetime inputs are expected as ISO 8601 formatted strings.
        (date(2023, 1, 15), "2023-01-15"),
        (
            datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC),
            "2023-01-15T12:30:45+00:00",
        ),
        # bytes input is expected as a "0x"-prefixed hex string.
        (b"\xde\xad\xbe\xef", "0xdeadbeef"),
        # str, float and int inputs are expected back unchanged.
        ("deadbeef", "deadbeef"),
        (199.99, 199.99),
        (69, 69),
    ],
)
async def test_value_conversion(
    input: Decimal | date | datetime | bytes | str | float,
    expected_output: str | float,
) -> None:
    """Check that ``convert_value`` maps each sample input to its expected output."""
    converted = convert_value(input)
    assert converted == expected_output