mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Set character set to utf8mb4 when connecting to MySQL or MariaDB databases (#79755)
This commit is contained in:
parent
bcbf99243d
commit
9aa6043255
@ -69,7 +69,10 @@ ALLOW_IN_MEMORY_DB = False
|
||||
def validate_db_url(db_url: str) -> Any:
|
||||
"""Validate database URL."""
|
||||
# Don't allow on-memory sqlite databases
|
||||
if (db_url == SQLITE_URL_PREFIX or ":memory:" in db_url) and not ALLOW_IN_MEMORY_DB:
|
||||
if (
|
||||
db_url == SQLITE_URL_PREFIX
|
||||
or (db_url.startswith(SQLITE_URL_PREFIX) and ":memory:" in db_url)
|
||||
) and not ALLOW_IN_MEMORY_DB:
|
||||
raise vol.Invalid("In-memory SQLite database is not supported")
|
||||
|
||||
return db_url
|
||||
|
@ -8,7 +8,10 @@ from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-im
|
||||
|
||||
DATA_INSTANCE = "recorder_instance"
|
||||
SQLITE_URL_PREFIX = "sqlite://"
|
||||
MARIADB_URL_PREFIX = "mariadb://"
|
||||
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"
|
||||
MYSQLDB_URL_PREFIX = "mysql://"
|
||||
MYSQLDB_PYMYSQL_URL_PREFIX = "mysql+pymysql://"
|
||||
DOMAIN = "recorder"
|
||||
|
||||
CONF_DB_INTEGRITY_CHECK = "db_integrity_check"
|
||||
|
@ -46,7 +46,10 @@ from .const import (
|
||||
DB_WORKER_PREFIX,
|
||||
DOMAIN,
|
||||
KEEPALIVE_TIME,
|
||||
MARIADB_PYMYSQL_URL_PREFIX,
|
||||
MARIADB_URL_PREFIX,
|
||||
MAX_QUEUE_BACKLOG,
|
||||
MYSQLDB_PYMYSQL_URL_PREFIX,
|
||||
MYSQLDB_URL_PREFIX,
|
||||
SQLITE_URL_PREFIX,
|
||||
SupportedDialect,
|
||||
@ -1114,14 +1117,23 @@ class Recorder(threading.Thread):
|
||||
kwargs["pool_reset_on_return"] = None
|
||||
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||
kwargs["poolclass"] = RecorderPool
|
||||
elif self.db_url.startswith(MYSQLDB_URL_PREFIX):
|
||||
# If they have configured MySQLDB but don't have
|
||||
# the MySQLDB module installed this will throw
|
||||
# an ImportError which we suppress here since
|
||||
# sqlalchemy will give them a better error when
|
||||
# it tried to import it below.
|
||||
with contextlib.suppress(ImportError):
|
||||
kwargs["connect_args"] = {"conv": build_mysqldb_conv()}
|
||||
elif self.db_url.startswith(
|
||||
(
|
||||
MARIADB_URL_PREFIX,
|
||||
MARIADB_PYMYSQL_URL_PREFIX,
|
||||
MYSQLDB_URL_PREFIX,
|
||||
MYSQLDB_PYMYSQL_URL_PREFIX,
|
||||
)
|
||||
):
|
||||
kwargs["connect_args"] = {"charset": "utf8mb4"}
|
||||
if self.db_url.startswith((MARIADB_URL_PREFIX, MYSQLDB_URL_PREFIX)):
|
||||
# If they have configured MySQLDB but don't have
|
||||
# the MySQLDB module installed this will throw
|
||||
# an ImportError which we suppress here since
|
||||
# sqlalchemy will give them a better error when
|
||||
# it tried to import it below.
|
||||
with contextlib.suppress(ImportError):
|
||||
kwargs["connect_args"]["conv"] = build_mysqldb_conv()
|
||||
|
||||
# Disable extended logging for non SQLite databases
|
||||
if not self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||
|
@ -17,12 +17,15 @@ from homeassistant.components.recorder import (
|
||||
CONF_AUTO_PURGE,
|
||||
CONF_AUTO_REPACK,
|
||||
CONF_COMMIT_INTERVAL,
|
||||
CONF_DB_MAX_RETRIES,
|
||||
CONF_DB_RETRY_WAIT,
|
||||
CONF_DB_URL,
|
||||
CONFIG_SCHEMA,
|
||||
DOMAIN,
|
||||
SQLITE_URL_PREFIX,
|
||||
Recorder,
|
||||
get_instance,
|
||||
pool,
|
||||
)
|
||||
from homeassistant.components.recorder.const import KEEPALIVE_TIME
|
||||
from homeassistant.components.recorder.db_schema import (
|
||||
@ -1626,3 +1629,129 @@ async def test_disable_echo(hass, db_url, echo, caplog):
|
||||
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: db_url}})
|
||||
create_engine_mock.assert_called_once()
|
||||
assert create_engine_mock.mock_calls[0][2].get("echo") == echo
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_url, connect_args",
|
||||
(
|
||||
(
|
||||
"mariadb://user:password@SERVER_IP/DB_NAME",
|
||||
{"charset": "utf8mb4"},
|
||||
),
|
||||
(
|
||||
"mariadb+pymysql://user:password@SERVER_IP/DB_NAME",
|
||||
{"charset": "utf8mb4"},
|
||||
),
|
||||
(
|
||||
"mysql://user:password@SERVER_IP/DB_NAME",
|
||||
{"charset": "utf8mb4"},
|
||||
),
|
||||
(
|
||||
"mysql+pymysql://user:password@SERVER_IP/DB_NAME",
|
||||
{"charset": "utf8mb4"},
|
||||
),
|
||||
(
|
||||
"mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4",
|
||||
{"charset": "utf8mb4"},
|
||||
),
|
||||
(
|
||||
"mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other",
|
||||
{"charset": "utf8mb4"},
|
||||
),
|
||||
(
|
||||
"postgresql://blabla",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"sqlite://blabla",
|
||||
None,
|
||||
),
|
||||
),
|
||||
)
|
||||
async def test_mysql_missing_utf8mb4(hass, config_url, connect_args):
|
||||
"""Test recorder fails to setup if charset=utf8mb4 is missing from db_url."""
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
|
||||
class MockEvent:
|
||||
def listen(self, _, _2, callback):
|
||||
callback(None, None)
|
||||
|
||||
mock_event = MockEvent()
|
||||
with patch(
|
||||
"homeassistant.components.recorder.core.create_engine"
|
||||
) as create_engine_mock, patch(
|
||||
"homeassistant.components.recorder.core.sqlalchemy_event", mock_event
|
||||
):
|
||||
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: config_url}})
|
||||
create_engine_mock.assert_called_once()
|
||||
assert create_engine_mock.mock_calls[0][2].get("connect_args") == connect_args
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_url",
|
||||
(
|
||||
"mysql://user:password@SERVER_IP/DB_NAME",
|
||||
"mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4",
|
||||
"mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other",
|
||||
),
|
||||
)
|
||||
async def test_connect_args_priority(hass, config_url):
|
||||
"""Test connect_args has priority over URL query."""
|
||||
connect_params = []
|
||||
recorder_helper.async_initialize_recorder(hass)
|
||||
|
||||
class MockDialect:
|
||||
"""Non functioning dialect, good enough that SQLAlchemy tries connecting."""
|
||||
|
||||
__bases__ = []
|
||||
_has_events = False
|
||||
|
||||
def __init__(*args, **kwargs):
|
||||
...
|
||||
|
||||
def connect(self, *args, **params):
|
||||
nonlocal connect_params
|
||||
connect_params.append(params)
|
||||
return True
|
||||
|
||||
def create_connect_args(self, url):
|
||||
return ([], {"charset": "invalid"})
|
||||
|
||||
@classmethod
|
||||
def dbapi(cls):
|
||||
...
|
||||
|
||||
def engine_created(*args):
|
||||
...
|
||||
|
||||
def get_dialect_pool_class(self, *args):
|
||||
return pool.RecorderPool
|
||||
|
||||
def initialize(*args):
|
||||
...
|
||||
|
||||
def on_connect_url(self, url):
|
||||
return False
|
||||
|
||||
class MockEntrypoint:
|
||||
def engine_created(*_):
|
||||
...
|
||||
|
||||
def get_dialect_cls(*_):
|
||||
return MockDialect
|
||||
|
||||
with patch("sqlalchemy.engine.url.URL._get_entrypoint", MockEntrypoint), patch(
|
||||
"sqlalchemy.engine.create.util.get_cls_kwargs", return_value=["echo"]
|
||||
):
|
||||
await async_setup_component(
|
||||
hass,
|
||||
DOMAIN,
|
||||
{
|
||||
DOMAIN: {
|
||||
CONF_DB_URL: config_url,
|
||||
CONF_DB_MAX_RETRIES: 1,
|
||||
CONF_DB_RETRY_WAIT: 0,
|
||||
}
|
||||
},
|
||||
)
|
||||
assert connect_params == [{"charset": "utf8mb4"}]
|
||||
|
Loading…
x
Reference in New Issue
Block a user