diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index f9ed5f59333..0d4bfe8e59b 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -69,7 +69,10 @@ ALLOW_IN_MEMORY_DB = False def validate_db_url(db_url: str) -> Any: """Validate database URL.""" # Don't allow on-memory sqlite databases - if (db_url == SQLITE_URL_PREFIX or ":memory:" in db_url) and not ALLOW_IN_MEMORY_DB: + if ( + db_url == SQLITE_URL_PREFIX + or (db_url.startswith(SQLITE_URL_PREFIX) and ":memory:" in db_url) + ) and not ALLOW_IN_MEMORY_DB: raise vol.Invalid("In-memory SQLite database is not supported") return db_url diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index 532644c7feb..66a9818b4b8 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -8,7 +8,10 @@ from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-im DATA_INSTANCE = "recorder_instance" SQLITE_URL_PREFIX = "sqlite://" +MARIADB_URL_PREFIX = "mariadb://" +MARIADB_PYMYSQL_URL_PREFIX = "mariadb+pymysql://" MYSQLDB_URL_PREFIX = "mysql://" +MYSQLDB_PYMYSQL_URL_PREFIX = "mysql+pymysql://" DOMAIN = "recorder" CONF_DB_INTEGRITY_CHECK = "db_integrity_check" diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 0511b42ebe4..032f1ff1ec2 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -46,7 +46,10 @@ from .const import ( DB_WORKER_PREFIX, DOMAIN, KEEPALIVE_TIME, + MARIADB_PYMYSQL_URL_PREFIX, + MARIADB_URL_PREFIX, MAX_QUEUE_BACKLOG, + MYSQLDB_PYMYSQL_URL_PREFIX, MYSQLDB_URL_PREFIX, SQLITE_URL_PREFIX, SupportedDialect, @@ -1114,14 +1117,23 @@ class Recorder(threading.Thread): kwargs["pool_reset_on_return"] = None elif self.db_url.startswith(SQLITE_URL_PREFIX): kwargs["poolclass"] = RecorderPool - elif self.db_url.startswith(MYSQLDB_URL_PREFIX): - # If they have configured MySQLDB but don't have - # the MySQLDB module installed this will throw - # an ImportError which we suppress here since - # sqlalchemy will give them a better error when - # it tried to import it below. - with contextlib.suppress(ImportError): - kwargs["connect_args"] = {"conv": build_mysqldb_conv()} + elif self.db_url.startswith( + ( + MARIADB_URL_PREFIX, + MARIADB_PYMYSQL_URL_PREFIX, + MYSQLDB_URL_PREFIX, + MYSQLDB_PYMYSQL_URL_PREFIX, + ) + ): + kwargs["connect_args"] = {"charset": "utf8mb4"} + if self.db_url.startswith((MARIADB_URL_PREFIX, MYSQLDB_URL_PREFIX)): + # If they have configured MySQLDB but don't have + # the MySQLDB module installed this will throw + # an ImportError which we suppress here since + # sqlalchemy will give them a better error when + # it tried to import it below. + with contextlib.suppress(ImportError): + kwargs["connect_args"]["conv"] = build_mysqldb_conv() # Disable extended logging for non SQLite databases if not self.db_url.startswith(SQLITE_URL_PREFIX): diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 4a801574ebb..815af89198d 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -17,12 +17,15 @@ from homeassistant.components.recorder import ( CONF_AUTO_PURGE, CONF_AUTO_REPACK, CONF_COMMIT_INTERVAL, + CONF_DB_MAX_RETRIES, + CONF_DB_RETRY_WAIT, CONF_DB_URL, CONFIG_SCHEMA, DOMAIN, SQLITE_URL_PREFIX, Recorder, get_instance, + pool, ) from homeassistant.components.recorder.const import KEEPALIVE_TIME from homeassistant.components.recorder.db_schema import ( @@ -1626,3 +1629,129 @@ async def test_disable_echo(hass, db_url, echo, caplog): await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: db_url}}) create_engine_mock.assert_called_once() assert create_engine_mock.mock_calls[0][2].get("echo") == echo + + +@pytest.mark.parametrize( + "config_url, connect_args", + ( + ( + "mariadb://user:password@SERVER_IP/DB_NAME", + {"charset": "utf8mb4"}, + ), + ( + "mariadb+pymysql://user:password@SERVER_IP/DB_NAME", + {"charset": "utf8mb4"}, + ), + ( + "mysql://user:password@SERVER_IP/DB_NAME", + {"charset": "utf8mb4"}, + ), + ( + "mysql+pymysql://user:password@SERVER_IP/DB_NAME", + {"charset": "utf8mb4"}, + ), + ( + "mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4", + {"charset": "utf8mb4"}, + ), + ( + "mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other", + {"charset": "utf8mb4"}, + ), + ( + "postgresql://blabla", + None, + ), + ( + "sqlite://blabla", + None, + ), + ), +) +async def test_mysql_missing_utf8mb4(hass, config_url, connect_args): + """Test recorder fails to setup if charset=utf8mb4 is missing from db_url.""" + recorder_helper.async_initialize_recorder(hass) + + class MockEvent: + def listen(self, _, _2, callback): + callback(None, None) + + mock_event = MockEvent() + with patch( + "homeassistant.components.recorder.core.create_engine" + ) as create_engine_mock, patch( + "homeassistant.components.recorder.core.sqlalchemy_event", mock_event + ): + await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: config_url}}) + create_engine_mock.assert_called_once() + assert create_engine_mock.mock_calls[0][2].get("connect_args") == connect_args + + +@pytest.mark.parametrize( + "config_url", + ( + "mysql://user:password@SERVER_IP/DB_NAME", + "mysql://user:password@SERVER_IP/DB_NAME?charset=utf8mb4", + "mysql://user:password@SERVER_IP/DB_NAME?blah=bleh&charset=other", + ), +) +async def test_connect_args_priority(hass, config_url): + """Test connect_args has priority over URL query.""" + connect_params = [] + recorder_helper.async_initialize_recorder(hass) + + class MockDialect: + """Non functioning dialect, good enough that SQLAlchemy tries connecting.""" + + __bases__ = [] + _has_events = False + + def __init__(*args, **kwargs): + ... + + def connect(self, *args, **params): + nonlocal connect_params + connect_params.append(params) + return True + + def create_connect_args(self, url): + return ([], {"charset": "invalid"}) + + @classmethod + def dbapi(cls): + ... + + def engine_created(*args): + ... + + def get_dialect_pool_class(self, *args): + return pool.RecorderPool + + def initialize(*args): + ... + + def on_connect_url(self, url): + return False + + class MockEntrypoint: + def engine_created(*_): + ... + + def get_dialect_cls(*_): + return MockDialect + + with patch("sqlalchemy.engine.url.URL._get_entrypoint", MockEntrypoint), patch( + "sqlalchemy.engine.create.util.get_cls_kwargs", return_value=["echo"] + ): + await async_setup_component( + hass, + DOMAIN, + { + DOMAIN: { + CONF_DB_URL: config_url, + CONF_DB_MAX_RETRIES: 1, + CONF_DB_RETRY_WAIT: 0, + } + }, + ) + assert connect_params == [{"charset": "utf8mb4"}]