From 01c4ca27499a011bfe6d74380613d5fc9044b923 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Tue, 4 Jun 2024 06:20:18 +0200 Subject: [PATCH] Recover mqtt abbrevations optimizations (#118762) Co-authored-by: J. Nick Koston --- homeassistant/components/mqtt/discovery.py | 143 ++++++++++++--------- tests/components/mqtt/test_discovery.py | 4 +- 2 files changed, 86 insertions(+), 61 deletions(-) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index e8a3ed9a8cb..0d93af26a57 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -41,6 +41,10 @@ from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage from .schemas import MQTT_ORIGIN_INFO_SCHEMA from .util import async_forward_entry_setup_and_setup_discovery +ABBREVIATIONS_SET = set(ABBREVIATIONS) +DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS) +ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS) + _LOGGER = logging.getLogger(__name__) TOPIC_MATCHER = re.compile( @@ -105,6 +109,82 @@ def async_log_discovery_origin_info( ) +@callback +def _replace_abbreviations( + payload: Any | dict[str, Any], + abbreviations: dict[str, str], + abbreviations_set: set[str], +) -> None: + """Replace abbreviations in an MQTT discovery payload.""" + if not isinstance(payload, dict): + return + for key in abbreviations_set.intersection(payload): + payload[abbreviations[key]] = payload.pop(key) + + +@callback +def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None: + """Replace all abbreviations in an MQTT discovery payload.""" + + _replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET) + + if CONF_ORIGIN in discovery_payload: + _replace_abbreviations( + discovery_payload[CONF_ORIGIN], + ORIGIN_ABBREVIATIONS, + ORIGIN_ABBREVIATIONS_SET, + ) + + if CONF_DEVICE in discovery_payload: + _replace_abbreviations( + discovery_payload[CONF_DEVICE], + DEVICE_ABBREVIATIONS, + DEVICE_ABBREVIATIONS_SET, + ) + + if CONF_AVAILABILITY in discovery_payload: + for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]): + _replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET) + + +@callback +def _replace_topic_base(discovery_payload: dict[str, Any]) -> None: + """Replace topic base in MQTT discovery data.""" + base = discovery_payload.pop(TOPIC_BASE) + for key, value in discovery_payload.items(): + if isinstance(value, str) and value: + if value[0] == TOPIC_BASE and key.endswith("topic"): + discovery_payload[key] = f"{base}{value[1:]}" + if value[-1] == TOPIC_BASE and key.endswith("topic"): + discovery_payload[key] = f"{value[:-1]}{base}" + if discovery_payload.get(CONF_AVAILABILITY): + for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]): + if not isinstance(availability_conf, dict): + continue + if topic := str(availability_conf.get(CONF_TOPIC)): + if topic[0] == TOPIC_BASE: + availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}" + if topic[-1] == TOPIC_BASE: + availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}" + + +@callback +def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool: + """Parse and validate origin info from a single component discovery payload.""" + if CONF_ORIGIN not in discovery_payload: + return True + try: + MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN]) + except Exception as exc: # noqa:BLE001 + _LOGGER.warning( + "Unable to parse origin information from discovery message: %s, got %s", + exc, + discovery_payload[CONF_ORIGIN], + ) + return False + return True + + async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry ) -> None: @@ -168,67 +248,14 @@ async def async_start( # noqa: C901 except ValueError: _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) return + _replace_all_abbreviations(discovery_payload) + if not _valid_origin_info(discovery_payload): + return + if TOPIC_BASE in discovery_payload: + _replace_topic_base(discovery_payload) else: discovery_payload = MQTTDiscoveryPayload({}) - for key in list(discovery_payload): - abbreviated_key = key - key = ABBREVIATIONS.get(key, key) - discovery_payload[key] = discovery_payload.pop(abbreviated_key) - - if CONF_DEVICE in discovery_payload: - device = discovery_payload[CONF_DEVICE] - for key in list(device): - abbreviated_key = key - key = DEVICE_ABBREVIATIONS.get(key, key) - device[key] = device.pop(abbreviated_key) - - if CONF_ORIGIN in discovery_payload: - origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN] - try: - for key in list(origin_info): - abbreviated_key = key - key = ORIGIN_ABBREVIATIONS.get(key, key) - origin_info[key] = origin_info.pop(abbreviated_key) - MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN]) - except Exception: # noqa: BLE001 - _LOGGER.warning( - "Unable to parse origin information " - "from discovery message, got %s", - discovery_payload[CONF_ORIGIN], - ) - return - - if CONF_AVAILABILITY in discovery_payload: - for availability_conf in cv.ensure_list( - discovery_payload[CONF_AVAILABILITY] - ): - if isinstance(availability_conf, dict): - for key in list(availability_conf): - abbreviated_key = key - key = ABBREVIATIONS.get(key, key) - availability_conf[key] = availability_conf.pop(abbreviated_key) - - if TOPIC_BASE in discovery_payload: - base = discovery_payload.pop(TOPIC_BASE) - for key, value in discovery_payload.items(): - if isinstance(value, str) and value: - if value[0] == TOPIC_BASE and key.endswith("topic"): - discovery_payload[key] = f"{base}{value[1:]}" - if value[-1] == TOPIC_BASE and key.endswith("topic"): - discovery_payload[key] = f"{value[:-1]}{base}" - if discovery_payload.get(CONF_AVAILABILITY): - for availability_conf in cv.ensure_list( - discovery_payload[CONF_AVAILABILITY] - ): - if not isinstance(availability_conf, dict): - continue - if topic := str(availability_conf.get(CONF_TOPIC)): - if topic[0] == TOPIC_BASE: - availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}" - if topic[-1] == TOPIC_BASE: - availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}" - # If present, the node_id will be included in the discovered object id discovery_id = f"{node_id} {object_id}" if node_id else object_id discovery_hash = (component, discovery_id) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 2e1f78c1bd4..020ab4a09a9 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -291,9 +291,7 @@ async def test_discovery_with_invalid_integration_info( state = hass.states.get("binary_sensor.beer") assert state is None - assert ( - "Unable to parse origin information from discovery message, got" in caplog.text - ) + assert "Unable to parse origin information from discovery message" in caplog.text async def test_discover_fan(