Recover mqtt abbrevations optimizations (#118762)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Jan Bouwhuis 2024-06-04 06:20:18 +02:00 committed by GitHub
parent 289263087c
commit e799270578
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 86 additions and 61 deletions

View File

@ -41,6 +41,10 @@ from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
from .util import async_forward_entry_setup_and_setup_discovery
ABBREVIATIONS_SET = set(ABBREVIATIONS)
DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS)
ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS)
_LOGGER = logging.getLogger(__name__)
TOPIC_MATCHER = re.compile(
@ -105,6 +109,82 @@ def async_log_discovery_origin_info(
)
@callback
def _replace_abbreviations(
payload: Any | dict[str, Any],
abbreviations: dict[str, str],
abbreviations_set: set[str],
) -> None:
"""Replace abbreviations in an MQTT discovery payload."""
if not isinstance(payload, dict):
return
for key in abbreviations_set.intersection(payload):
payload[abbreviations[key]] = payload.pop(key)
@callback
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
"""Replace all abbreviations in an MQTT discovery payload."""
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
if CONF_ORIGIN in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_ORIGIN],
ORIGIN_ABBREVIATIONS,
ORIGIN_ABBREVIATIONS_SET,
)
if CONF_DEVICE in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_DEVICE],
DEVICE_ABBREVIATIONS,
DEVICE_ABBREVIATIONS_SET,
)
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
@callback
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
"""Replace topic base in MQTT discovery data."""
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
if not isinstance(availability_conf, dict):
continue
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
@callback
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
"""Parse and validate origin info from a single component discovery payload."""
if CONF_ORIGIN not in discovery_payload:
return True
try:
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
except Exception as exc: # noqa:BLE001
_LOGGER.warning(
"Unable to parse origin information from discovery message: %s, got %s",
exc,
discovery_payload[CONF_ORIGIN],
)
return False
return True
async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None:
@ -168,67 +248,14 @@ async def async_start( # noqa: C901
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
_replace_all_abbreviations(discovery_payload)
if not _valid_origin_info(discovery_payload):
return
if TOPIC_BASE in discovery_payload:
_replace_topic_base(discovery_payload)
else:
discovery_payload = MQTTDiscoveryPayload({})
for key in list(discovery_payload):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
if CONF_DEVICE in discovery_payload:
device = discovery_payload[CONF_DEVICE]
for key in list(device):
abbreviated_key = key
key = DEVICE_ABBREVIATIONS.get(key, key)
device[key] = device.pop(abbreviated_key)
if CONF_ORIGIN in discovery_payload:
origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN]
try:
for key in list(origin_info):
abbreviated_key = key
key = ORIGIN_ABBREVIATIONS.get(key, key)
origin_info[key] = origin_info.pop(abbreviated_key)
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
except Exception: # noqa: BLE001
_LOGGER.warning(
"Unable to parse origin information "
"from discovery message, got %s",
discovery_payload[CONF_ORIGIN],
)
return
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if isinstance(availability_conf, dict):
for key in list(availability_conf):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
availability_conf[key] = availability_conf.pop(abbreviated_key)
if TOPIC_BASE in discovery_payload:
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if not isinstance(availability_conf, dict):
continue
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
# If present, the node_id will be included in the discovered object id
discovery_id = f"{node_id} {object_id}" if node_id else object_id
discovery_hash = (component, discovery_id)

View File

@ -291,9 +291,7 @@ async def test_discovery_with_invalid_integration_info(
state = hass.states.get("binary_sensor.beer")
assert state is None
assert (
"Unable to parse origin information from discovery message, got" in caplog.text
)
assert "Unable to parse origin information from discovery message" in caplog.text
async def test_discover_fan(