Set character set to utf8mb4 when connecting to MySQL or MariaDB databases (#79755)

This commit is contained in:
Erik Montnemery 2022-10-11 14:01:46 +02:00 committed by GitHub
parent bcbf99243d
commit 9aa6043255
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 156 additions and 9 deletions

View File

@ -69,7 +69,10 @@ ALLOW_IN_MEMORY_DB = False
def validate_db_url(db_url: str) -> Any:
"""Validate database URL."""
# Don't allow on-memory sqlite databases
if (db_url == SQLITE_URL_PREFIX or ":memory:" in db_url) and not ALLOW_IN_MEMORY_DB:
if (
db_url == SQLITE_URL_PREFIX
or (db_url.startswith(SQLITE_URL_PREFIX) and ":memory:" in db_url)
) and not ALLOW_IN_MEMORY_DB:
raise vol.Invalid("In-memory SQLite database is not supported")
return db_url

View File

@ -8,7 +8,10 @@ from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-im
DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"
MARIADB_URL_PREFIX = "mariadb://"
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"
MYSQLDB_URL_PREFIX = "mysql://"
MYSQLDB_PYMYSQL_URL_PREFIX = "mysql+pymysql://"
DOMAIN = "recorder"
CONF_DB_INTEGRITY_CHECK = "db_integrity_check"

View File

@ -46,7 +46,10 @@ from .const import (
DB_WORKER_PREFIX,
DOMAIN,
KEEPALIVE_TIME,
MARIADB_PYMYSQL_URL_PREFIX,
MARIADB_URL_PREFIX,
MAX_QUEUE_BACKLOG,
MYSQLDB_PYMYSQL_URL_PREFIX,
MYSQLDB_URL_PREFIX,
SQLITE_URL_PREFIX,
SupportedDialect,
@ -1114,14 +1117,23 @@ class Recorder(threading.Thread):
kwargs["pool_reset_on_return"] = None
elif self.db_url.startswith(SQLITE_URL_PREFIX):
kwargs["poolclass"] = RecorderPool
elif self.db_url.startswith(MYSQLDB_URL_PREFIX):
# If they have configured MySQLDB but don't have
# the MySQLDB module installed this will throw
# an ImportError which we suppress here since
# sqlalchemy will give them a better error when
# it tried to import it below.
with contextlib.suppress(ImportError):
kwargs["connect_args"] = {"conv": build_mysqldb_conv()}
elif self.db_url.startswith(
(
MARIADB_URL_PREFIX,
MARIADB_PYMYSQL_URL_PREFIX,
MYSQLDB_URL_PREFIX,
MYSQLDB_PYMYSQL_URL_PREFIX,
)
):
kwargs["connect_args"] = {"charset": "utf8mb4"}
if self.db_url.startswith((MARIADB_URL_PREFIX, MYSQLDB_URL_PREFIX)):
# If they have configured MySQLDB but don't have
# the MySQLDB module installed this will throw
# an ImportError which we suppress here since
# sqlalchemy will give them a better error when
# it tried to import it below.
with contextlib.suppress(ImportError):
kwargs["connect_args"]["conv"] = build_mysqldb_conv()
# Disable extended logging for non SQLite databases
if not self.db_url.startswith(SQLITE_URL_PREFIX):

View File

@ -17,12 +17,15 @@ from homeassistant.components.recorder import (
CONF_AUTO_PURGE,
CONF_AUTO_REPACK,
CONF_COMMIT_INTERVAL,
CONF_DB_MAX_RETRIES,
CONF_DB_RETRY_WAIT,
CONF_DB_URL,
CONFIG_SCHEMA,
DOMAIN,
SQLITE_URL_PREFIX,
Recorder,
get_instance,
pool,
)
from homeassistant.components.recorder.const import KEEPALIVE_TIME
from homeassistant.components.recorder.db_schema import (
@ -1626,3 +1629,129 @@ async def test_disable_echo(hass, db_url, echo, caplog):
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: db_url}})
create_engine_mock.assert_called_once()
assert create_engine_mock.mock_calls[0][2].get("echo") == echo
@pytest.mark.parametrize(
"config_url, connect_args",
(
(
"mariadb://user:password@SERVER_IP/DB_NAME",
{"charset": "utf8mb4"},
),
(
"mariadb+pymysql://user:password@SERVER_IP/DB_NAME",
{"charset": "utf8mb4"},
),
(
"mysql://user:password@SERVER_IP/DB_NAME",
{"charset": "utf8mb4"},
),
(
"mysql+pymysql://user:password@SERVER_IP/DB_NAME",
{"charset": "utf8mb4"},
),
(
"mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4",
{"charset": "utf8mb4"},
),
(
"mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other",
{"charset": "utf8mb4"},
),
(
"postgresql://blabla",
None,
),
(
"sqlite://blabla",
None,
),
),
)
async def test_mysql_missing_utf8mb4(hass, config_url, connect_args):
"""Test recorder fails to setup if charset=utf8mb4 is missing from db_url."""
recorder_helper.async_initialize_recorder(hass)
class MockEvent:
def listen(self, _, _2, callback):
callback(None, None)
mock_event = MockEvent()
with patch(
"homeassistant.components.recorder.core.create_engine"
) as create_engine_mock, patch(
"homeassistant.components.recorder.core.sqlalchemy_event", mock_event
):
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: config_url}})
create_engine_mock.assert_called_once()
assert create_engine_mock.mock_calls[0][2].get("connect_args") == connect_args
@pytest.mark.parametrize(
"config_url",
(
"mysql://user:password@SERVER_IP/DB_NAME",
"mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4",
"mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other",
),
)
async def test_connect_args_priority(hass, config_url):
"""Test connect_args has priority over URL query."""
connect_params = []
recorder_helper.async_initialize_recorder(hass)
class MockDialect:
"""Non functioning dialect, good enough that SQLAlchemy tries connecting."""
__bases__ = []
_has_events = False
def __init__(*args, **kwargs):
...
def connect(self, *args, **params):
nonlocal connect_params
connect_params.append(params)
return True
def create_connect_args(self, url):
return ([], {"charset": "invalid"})
@classmethod
def dbapi(cls):
...
def engine_created(*args):
...
def get_dialect_pool_class(self, *args):
return pool.RecorderPool
def initialize(*args):
...
def on_connect_url(self, url):
return False
class MockEntrypoint:
def engine_created(*_):
...
def get_dialect_cls(*_):
return MockDialect
with patch("sqlalchemy.engine.url.URL._get_entrypoint", MockEntrypoint), patch(
"sqlalchemy.engine.create.util.get_cls_kwargs", return_value=["echo"]
):
await async_setup_component(
hass,
DOMAIN,
{
DOMAIN: {
CONF_DB_URL: config_url,
CONF_DB_MAX_RETRIES: 1,
CONF_DB_RETRY_WAIT: 0,
}
},
)
assert connect_params == [{"charset": "utf8mb4"}]