mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Set character set to utf8mb4 when connecting to MySQL or MariaDB databases (#79755)
This commit is contained in:
parent
bcbf99243d
commit
9aa6043255
@ -69,7 +69,10 @@ ALLOW_IN_MEMORY_DB = False
|
|||||||
def validate_db_url(db_url: str) -> Any:
|
def validate_db_url(db_url: str) -> Any:
|
||||||
"""Validate database URL."""
|
"""Validate database URL."""
|
||||||
# Don't allow on-memory sqlite databases
|
# Don't allow on-memory sqlite databases
|
||||||
if (db_url == SQLITE_URL_PREFIX or ":memory:" in db_url) and not ALLOW_IN_MEMORY_DB:
|
if (
|
||||||
|
db_url == SQLITE_URL_PREFIX
|
||||||
|
or (db_url.startswith(SQLITE_URL_PREFIX) and ":memory:" in db_url)
|
||||||
|
) and not ALLOW_IN_MEMORY_DB:
|
||||||
raise vol.Invalid("In-memory SQLite database is not supported")
|
raise vol.Invalid("In-memory SQLite database is not supported")
|
||||||
|
|
||||||
return db_url
|
return db_url
|
||||||
|
@ -8,7 +8,10 @@ from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-im
|
|||||||
|
|
||||||
DATA_INSTANCE = "recorder_instance"
|
DATA_INSTANCE = "recorder_instance"
|
||||||
SQLITE_URL_PREFIX = "sqlite://"
|
SQLITE_URL_PREFIX = "sqlite://"
|
||||||
|
MARIADB_URL_PREFIX = "mariadb://"
|
||||||
|
MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://"
|
||||||
MYSQLDB_URL_PREFIX = "mysql://"
|
MYSQLDB_URL_PREFIX = "mysql://"
|
||||||
|
MYSQLDB_PYMYSQL_URL_PREFIX = "mysql+pymysql://"
|
||||||
DOMAIN = "recorder"
|
DOMAIN = "recorder"
|
||||||
|
|
||||||
CONF_DB_INTEGRITY_CHECK = "db_integrity_check"
|
CONF_DB_INTEGRITY_CHECK = "db_integrity_check"
|
||||||
|
@ -46,7 +46,10 @@ from .const import (
|
|||||||
DB_WORKER_PREFIX,
|
DB_WORKER_PREFIX,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
KEEPALIVE_TIME,
|
KEEPALIVE_TIME,
|
||||||
|
MARIADB_PYMYSQL_URL_PREFIX,
|
||||||
|
MARIADB_URL_PREFIX,
|
||||||
MAX_QUEUE_BACKLOG,
|
MAX_QUEUE_BACKLOG,
|
||||||
|
MYSQLDB_PYMYSQL_URL_PREFIX,
|
||||||
MYSQLDB_URL_PREFIX,
|
MYSQLDB_URL_PREFIX,
|
||||||
SQLITE_URL_PREFIX,
|
SQLITE_URL_PREFIX,
|
||||||
SupportedDialect,
|
SupportedDialect,
|
||||||
@ -1114,14 +1117,23 @@ class Recorder(threading.Thread):
|
|||||||
kwargs["pool_reset_on_return"] = None
|
kwargs["pool_reset_on_return"] = None
|
||||||
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||||
kwargs["poolclass"] = RecorderPool
|
kwargs["poolclass"] = RecorderPool
|
||||||
elif self.db_url.startswith(MYSQLDB_URL_PREFIX):
|
elif self.db_url.startswith(
|
||||||
|
(
|
||||||
|
MARIADB_URL_PREFIX,
|
||||||
|
MARIADB_PYMYSQL_URL_PREFIX,
|
||||||
|
MYSQLDB_URL_PREFIX,
|
||||||
|
MYSQLDB_PYMYSQL_URL_PREFIX,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
kwargs["connect_args"] = {"charset": "utf8mb4"}
|
||||||
|
if self.db_url.startswith((MARIADB_URL_PREFIX, MYSQLDB_URL_PREFIX)):
|
||||||
# If they have configured MySQLDB but don't have
|
# If they have configured MySQLDB but don't have
|
||||||
# the MySQLDB module installed this will throw
|
# the MySQLDB module installed this will throw
|
||||||
# an ImportError which we suppress here since
|
# an ImportError which we suppress here since
|
||||||
# sqlalchemy will give them a better error when
|
# sqlalchemy will give them a better error when
|
||||||
# it tried to import it below.
|
# it tried to import it below.
|
||||||
with contextlib.suppress(ImportError):
|
with contextlib.suppress(ImportError):
|
||||||
kwargs["connect_args"] = {"conv": build_mysqldb_conv()}
|
kwargs["connect_args"]["conv"] = build_mysqldb_conv()
|
||||||
|
|
||||||
# Disable extended logging for non SQLite databases
|
# Disable extended logging for non SQLite databases
|
||||||
if not self.db_url.startswith(SQLITE_URL_PREFIX):
|
if not self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||||
|
@ -17,12 +17,15 @@ from homeassistant.components.recorder import (
|
|||||||
CONF_AUTO_PURGE,
|
CONF_AUTO_PURGE,
|
||||||
CONF_AUTO_REPACK,
|
CONF_AUTO_REPACK,
|
||||||
CONF_COMMIT_INTERVAL,
|
CONF_COMMIT_INTERVAL,
|
||||||
|
CONF_DB_MAX_RETRIES,
|
||||||
|
CONF_DB_RETRY_WAIT,
|
||||||
CONF_DB_URL,
|
CONF_DB_URL,
|
||||||
CONFIG_SCHEMA,
|
CONFIG_SCHEMA,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
SQLITE_URL_PREFIX,
|
SQLITE_URL_PREFIX,
|
||||||
Recorder,
|
Recorder,
|
||||||
get_instance,
|
get_instance,
|
||||||
|
pool,
|
||||||
)
|
)
|
||||||
from homeassistant.components.recorder.const import KEEPALIVE_TIME
|
from homeassistant.components.recorder.const import KEEPALIVE_TIME
|
||||||
from homeassistant.components.recorder.db_schema import (
|
from homeassistant.components.recorder.db_schema import (
|
||||||
@ -1626,3 +1629,129 @@ async def test_disable_echo(hass, db_url, echo, caplog):
|
|||||||
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: db_url}})
|
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: db_url}})
|
||||||
create_engine_mock.assert_called_once()
|
create_engine_mock.assert_called_once()
|
||||||
assert create_engine_mock.mock_calls[0][2].get("echo") == echo
|
assert create_engine_mock.mock_calls[0][2].get("echo") == echo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config_url, connect_args",
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"mariadb://user:password@SERVER_IP/DB_NAME",
|
||||||
|
{"charset": "utf8mb4"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mariadb+pymysql://user:password@SERVER_IP/DB_NAME",
|
||||||
|
{"charset": "utf8mb4"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mysql://user:password@SERVER_IP/DB_NAME",
|
||||||
|
{"charset": "utf8mb4"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mysql+pymysql://user:password@SERVER_IP/DB_NAME",
|
||||||
|
{"charset": "utf8mb4"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4",
|
||||||
|
{"charset": "utf8mb4"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other",
|
||||||
|
{"charset": "utf8mb4"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"postgresql://blabla",
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"sqlite://blabla",
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_mysql_missing_utf8mb4(hass, config_url, connect_args):
|
||||||
|
"""Test recorder fails to setup if charset=utf8mb4 is missing from db_url."""
|
||||||
|
recorder_helper.async_initialize_recorder(hass)
|
||||||
|
|
||||||
|
class MockEvent:
|
||||||
|
def listen(self, _, _2, callback):
|
||||||
|
callback(None, None)
|
||||||
|
|
||||||
|
mock_event = MockEvent()
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.recorder.core.create_engine"
|
||||||
|
) as create_engine_mock, patch(
|
||||||
|
"homeassistant.components.recorder.core.sqlalchemy_event", mock_event
|
||||||
|
):
|
||||||
|
await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: config_url}})
|
||||||
|
create_engine_mock.assert_called_once()
|
||||||
|
assert create_engine_mock.mock_calls[0][2].get("connect_args") == connect_args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config_url",
|
||||||
|
(
|
||||||
|
"mysql://user:password@SERVER_IP/DB_NAME",
|
||||||
|
"mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4",
|
||||||
|
"mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_connect_args_priority(hass, config_url):
|
||||||
|
"""Test connect_args has priority over URL query."""
|
||||||
|
connect_params = []
|
||||||
|
recorder_helper.async_initialize_recorder(hass)
|
||||||
|
|
||||||
|
class MockDialect:
|
||||||
|
"""Non functioning dialect, good enough that SQLAlchemy tries connecting."""
|
||||||
|
|
||||||
|
__bases__ = []
|
||||||
|
_has_events = False
|
||||||
|
|
||||||
|
def __init__(*args, **kwargs):
|
||||||
|
...
|
||||||
|
|
||||||
|
def connect(self, *args, **params):
|
||||||
|
nonlocal connect_params
|
||||||
|
connect_params.append(params)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def create_connect_args(self, url):
|
||||||
|
return ([], {"charset": "invalid"})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dbapi(cls):
|
||||||
|
...
|
||||||
|
|
||||||
|
def engine_created(*args):
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_dialect_pool_class(self, *args):
|
||||||
|
return pool.RecorderPool
|
||||||
|
|
||||||
|
def initialize(*args):
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_connect_url(self, url):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class MockEntrypoint:
|
||||||
|
def engine_created(*_):
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_dialect_cls(*_):
|
||||||
|
return MockDialect
|
||||||
|
|
||||||
|
with patch("sqlalchemy.engine.url.URL._get_entrypoint", MockEntrypoint), patch(
|
||||||
|
"sqlalchemy.engine.create.util.get_cls_kwargs", return_value=["echo"]
|
||||||
|
):
|
||||||
|
await async_setup_component(
|
||||||
|
hass,
|
||||||
|
DOMAIN,
|
||||||
|
{
|
||||||
|
DOMAIN: {
|
||||||
|
CONF_DB_URL: config_url,
|
||||||
|
CONF_DB_MAX_RETRIES: 1,
|
||||||
|
CONF_DB_RETRY_WAIT: 0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert connect_params == [{"charset": "utf8mb4"}]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user