mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Recover mqtt abbrevations optimizations (#118762)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
289263087c
commit
e799270578
@ -41,6 +41,10 @@ from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
|
||||
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
|
||||
from .util import async_forward_entry_setup_and_setup_discovery
|
||||
|
||||
ABBREVIATIONS_SET = set(ABBREVIATIONS)
|
||||
DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS)
|
||||
ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
TOPIC_MATCHER = re.compile(
|
||||
@ -105,6 +109,82 @@ def async_log_discovery_origin_info(
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _replace_abbreviations(
|
||||
payload: Any | dict[str, Any],
|
||||
abbreviations: dict[str, str],
|
||||
abbreviations_set: set[str],
|
||||
) -> None:
|
||||
"""Replace abbreviations in an MQTT discovery payload."""
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
for key in abbreviations_set.intersection(payload):
|
||||
payload[abbreviations[key]] = payload.pop(key)
|
||||
|
||||
|
||||
@callback
|
||||
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
|
||||
"""Replace all abbreviations in an MQTT discovery payload."""
|
||||
|
||||
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||
|
||||
if CONF_ORIGIN in discovery_payload:
|
||||
_replace_abbreviations(
|
||||
discovery_payload[CONF_ORIGIN],
|
||||
ORIGIN_ABBREVIATIONS,
|
||||
ORIGIN_ABBREVIATIONS_SET,
|
||||
)
|
||||
|
||||
if CONF_DEVICE in discovery_payload:
|
||||
_replace_abbreviations(
|
||||
discovery_payload[CONF_DEVICE],
|
||||
DEVICE_ABBREVIATIONS,
|
||||
DEVICE_ABBREVIATIONS_SET,
|
||||
)
|
||||
|
||||
if CONF_AVAILABILITY in discovery_payload:
|
||||
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
|
||||
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||
|
||||
|
||||
@callback
|
||||
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
|
||||
"""Replace topic base in MQTT discovery data."""
|
||||
base = discovery_payload.pop(TOPIC_BASE)
|
||||
for key, value in discovery_payload.items():
|
||||
if isinstance(value, str) and value:
|
||||
if value[0] == TOPIC_BASE and key.endswith("topic"):
|
||||
discovery_payload[key] = f"{base}{value[1:]}"
|
||||
if value[-1] == TOPIC_BASE and key.endswith("topic"):
|
||||
discovery_payload[key] = f"{value[:-1]}{base}"
|
||||
if discovery_payload.get(CONF_AVAILABILITY):
|
||||
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
|
||||
if not isinstance(availability_conf, dict):
|
||||
continue
|
||||
if topic := str(availability_conf.get(CONF_TOPIC)):
|
||||
if topic[0] == TOPIC_BASE:
|
||||
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
|
||||
if topic[-1] == TOPIC_BASE:
|
||||
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
|
||||
|
||||
|
||||
@callback
|
||||
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
|
||||
"""Parse and validate origin info from a single component discovery payload."""
|
||||
if CONF_ORIGIN not in discovery_payload:
|
||||
return True
|
||||
try:
|
||||
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
|
||||
except Exception as exc: # noqa:BLE001
|
||||
_LOGGER.warning(
|
||||
"Unable to parse origin information from discovery message: %s, got %s",
|
||||
exc,
|
||||
discovery_payload[CONF_ORIGIN],
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def async_start( # noqa: C901
|
||||
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
|
||||
) -> None:
|
||||
@ -168,67 +248,14 @@ async def async_start( # noqa: C901
|
||||
except ValueError:
|
||||
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
|
||||
return
|
||||
_replace_all_abbreviations(discovery_payload)
|
||||
if not _valid_origin_info(discovery_payload):
|
||||
return
|
||||
if TOPIC_BASE in discovery_payload:
|
||||
_replace_topic_base(discovery_payload)
|
||||
else:
|
||||
discovery_payload = MQTTDiscoveryPayload({})
|
||||
|
||||
for key in list(discovery_payload):
|
||||
abbreviated_key = key
|
||||
key = ABBREVIATIONS.get(key, key)
|
||||
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
|
||||
|
||||
if CONF_DEVICE in discovery_payload:
|
||||
device = discovery_payload[CONF_DEVICE]
|
||||
for key in list(device):
|
||||
abbreviated_key = key
|
||||
key = DEVICE_ABBREVIATIONS.get(key, key)
|
||||
device[key] = device.pop(abbreviated_key)
|
||||
|
||||
if CONF_ORIGIN in discovery_payload:
|
||||
origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN]
|
||||
try:
|
||||
for key in list(origin_info):
|
||||
abbreviated_key = key
|
||||
key = ORIGIN_ABBREVIATIONS.get(key, key)
|
||||
origin_info[key] = origin_info.pop(abbreviated_key)
|
||||
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.warning(
|
||||
"Unable to parse origin information "
|
||||
"from discovery message, got %s",
|
||||
discovery_payload[CONF_ORIGIN],
|
||||
)
|
||||
return
|
||||
|
||||
if CONF_AVAILABILITY in discovery_payload:
|
||||
for availability_conf in cv.ensure_list(
|
||||
discovery_payload[CONF_AVAILABILITY]
|
||||
):
|
||||
if isinstance(availability_conf, dict):
|
||||
for key in list(availability_conf):
|
||||
abbreviated_key = key
|
||||
key = ABBREVIATIONS.get(key, key)
|
||||
availability_conf[key] = availability_conf.pop(abbreviated_key)
|
||||
|
||||
if TOPIC_BASE in discovery_payload:
|
||||
base = discovery_payload.pop(TOPIC_BASE)
|
||||
for key, value in discovery_payload.items():
|
||||
if isinstance(value, str) and value:
|
||||
if value[0] == TOPIC_BASE and key.endswith("topic"):
|
||||
discovery_payload[key] = f"{base}{value[1:]}"
|
||||
if value[-1] == TOPIC_BASE and key.endswith("topic"):
|
||||
discovery_payload[key] = f"{value[:-1]}{base}"
|
||||
if discovery_payload.get(CONF_AVAILABILITY):
|
||||
for availability_conf in cv.ensure_list(
|
||||
discovery_payload[CONF_AVAILABILITY]
|
||||
):
|
||||
if not isinstance(availability_conf, dict):
|
||||
continue
|
||||
if topic := str(availability_conf.get(CONF_TOPIC)):
|
||||
if topic[0] == TOPIC_BASE:
|
||||
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
|
||||
if topic[-1] == TOPIC_BASE:
|
||||
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
|
||||
|
||||
# If present, the node_id will be included in the discovered object id
|
||||
discovery_id = f"{node_id} {object_id}" if node_id else object_id
|
||||
discovery_hash = (component, discovery_id)
|
||||
|
@ -291,9 +291,7 @@ async def test_discovery_with_invalid_integration_info(
|
||||
state = hass.states.get("binary_sensor.beer")
|
||||
|
||||
assert state is None
|
||||
assert (
|
||||
"Unable to parse origin information from discovery message, got" in caplog.text
|
||||
)
|
||||
assert "Unable to parse origin information from discovery message" in caplog.text
|
||||
|
||||
|
||||
async def test_discover_fan(
|
||||
|
Loading…
x
Reference in New Issue
Block a user