mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Recover mqtt abbrevations optimizations (#118762)
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
289263087c
commit
e799270578
@ -41,6 +41,10 @@ from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
|
|||||||
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
|
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
|
||||||
from .util import async_forward_entry_setup_and_setup_discovery
|
from .util import async_forward_entry_setup_and_setup_discovery
|
||||||
|
|
||||||
|
ABBREVIATIONS_SET = set(ABBREVIATIONS)
|
||||||
|
DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS)
|
||||||
|
ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
TOPIC_MATCHER = re.compile(
|
TOPIC_MATCHER = re.compile(
|
||||||
@ -105,6 +109,82 @@ def async_log_discovery_origin_info(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _replace_abbreviations(
|
||||||
|
payload: Any | dict[str, Any],
|
||||||
|
abbreviations: dict[str, str],
|
||||||
|
abbreviations_set: set[str],
|
||||||
|
) -> None:
|
||||||
|
"""Replace abbreviations in an MQTT discovery payload."""
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return
|
||||||
|
for key in abbreviations_set.intersection(payload):
|
||||||
|
payload[abbreviations[key]] = payload.pop(key)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
|
||||||
|
"""Replace all abbreviations in an MQTT discovery payload."""
|
||||||
|
|
||||||
|
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||||
|
|
||||||
|
if CONF_ORIGIN in discovery_payload:
|
||||||
|
_replace_abbreviations(
|
||||||
|
discovery_payload[CONF_ORIGIN],
|
||||||
|
ORIGIN_ABBREVIATIONS,
|
||||||
|
ORIGIN_ABBREVIATIONS_SET,
|
||||||
|
)
|
||||||
|
|
||||||
|
if CONF_DEVICE in discovery_payload:
|
||||||
|
_replace_abbreviations(
|
||||||
|
discovery_payload[CONF_DEVICE],
|
||||||
|
DEVICE_ABBREVIATIONS,
|
||||||
|
DEVICE_ABBREVIATIONS_SET,
|
||||||
|
)
|
||||||
|
|
||||||
|
if CONF_AVAILABILITY in discovery_payload:
|
||||||
|
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
|
||||||
|
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
|
||||||
|
"""Replace topic base in MQTT discovery data."""
|
||||||
|
base = discovery_payload.pop(TOPIC_BASE)
|
||||||
|
for key, value in discovery_payload.items():
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
if value[0] == TOPIC_BASE and key.endswith("topic"):
|
||||||
|
discovery_payload[key] = f"{base}{value[1:]}"
|
||||||
|
if value[-1] == TOPIC_BASE and key.endswith("topic"):
|
||||||
|
discovery_payload[key] = f"{value[:-1]}{base}"
|
||||||
|
if discovery_payload.get(CONF_AVAILABILITY):
|
||||||
|
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
|
||||||
|
if not isinstance(availability_conf, dict):
|
||||||
|
continue
|
||||||
|
if topic := str(availability_conf.get(CONF_TOPIC)):
|
||||||
|
if topic[0] == TOPIC_BASE:
|
||||||
|
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
|
||||||
|
if topic[-1] == TOPIC_BASE:
|
||||||
|
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
|
||||||
|
"""Parse and validate origin info from a single component discovery payload."""
|
||||||
|
if CONF_ORIGIN not in discovery_payload:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
|
||||||
|
except Exception as exc: # noqa:BLE001
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Unable to parse origin information from discovery message: %s, got %s",
|
||||||
|
exc,
|
||||||
|
discovery_payload[CONF_ORIGIN],
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_start( # noqa: C901
|
async def async_start( # noqa: C901
|
||||||
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
|
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -168,67 +248,14 @@ async def async_start( # noqa: C901
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
|
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
|
||||||
return
|
return
|
||||||
|
_replace_all_abbreviations(discovery_payload)
|
||||||
|
if not _valid_origin_info(discovery_payload):
|
||||||
|
return
|
||||||
|
if TOPIC_BASE in discovery_payload:
|
||||||
|
_replace_topic_base(discovery_payload)
|
||||||
else:
|
else:
|
||||||
discovery_payload = MQTTDiscoveryPayload({})
|
discovery_payload = MQTTDiscoveryPayload({})
|
||||||
|
|
||||||
for key in list(discovery_payload):
|
|
||||||
abbreviated_key = key
|
|
||||||
key = ABBREVIATIONS.get(key, key)
|
|
||||||
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
|
|
||||||
|
|
||||||
if CONF_DEVICE in discovery_payload:
|
|
||||||
device = discovery_payload[CONF_DEVICE]
|
|
||||||
for key in list(device):
|
|
||||||
abbreviated_key = key
|
|
||||||
key = DEVICE_ABBREVIATIONS.get(key, key)
|
|
||||||
device[key] = device.pop(abbreviated_key)
|
|
||||||
|
|
||||||
if CONF_ORIGIN in discovery_payload:
|
|
||||||
origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN]
|
|
||||||
try:
|
|
||||||
for key in list(origin_info):
|
|
||||||
abbreviated_key = key
|
|
||||||
key = ORIGIN_ABBREVIATIONS.get(key, key)
|
|
||||||
origin_info[key] = origin_info.pop(abbreviated_key)
|
|
||||||
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Unable to parse origin information "
|
|
||||||
"from discovery message, got %s",
|
|
||||||
discovery_payload[CONF_ORIGIN],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if CONF_AVAILABILITY in discovery_payload:
|
|
||||||
for availability_conf in cv.ensure_list(
|
|
||||||
discovery_payload[CONF_AVAILABILITY]
|
|
||||||
):
|
|
||||||
if isinstance(availability_conf, dict):
|
|
||||||
for key in list(availability_conf):
|
|
||||||
abbreviated_key = key
|
|
||||||
key = ABBREVIATIONS.get(key, key)
|
|
||||||
availability_conf[key] = availability_conf.pop(abbreviated_key)
|
|
||||||
|
|
||||||
if TOPIC_BASE in discovery_payload:
|
|
||||||
base = discovery_payload.pop(TOPIC_BASE)
|
|
||||||
for key, value in discovery_payload.items():
|
|
||||||
if isinstance(value, str) and value:
|
|
||||||
if value[0] == TOPIC_BASE and key.endswith("topic"):
|
|
||||||
discovery_payload[key] = f"{base}{value[1:]}"
|
|
||||||
if value[-1] == TOPIC_BASE and key.endswith("topic"):
|
|
||||||
discovery_payload[key] = f"{value[:-1]}{base}"
|
|
||||||
if discovery_payload.get(CONF_AVAILABILITY):
|
|
||||||
for availability_conf in cv.ensure_list(
|
|
||||||
discovery_payload[CONF_AVAILABILITY]
|
|
||||||
):
|
|
||||||
if not isinstance(availability_conf, dict):
|
|
||||||
continue
|
|
||||||
if topic := str(availability_conf.get(CONF_TOPIC)):
|
|
||||||
if topic[0] == TOPIC_BASE:
|
|
||||||
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
|
|
||||||
if topic[-1] == TOPIC_BASE:
|
|
||||||
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
|
|
||||||
|
|
||||||
# If present, the node_id will be included in the discovered object id
|
# If present, the node_id will be included in the discovered object id
|
||||||
discovery_id = f"{node_id} {object_id}" if node_id else object_id
|
discovery_id = f"{node_id} {object_id}" if node_id else object_id
|
||||||
discovery_hash = (component, discovery_id)
|
discovery_hash = (component, discovery_id)
|
||||||
|
@ -291,9 +291,7 @@ async def test_discovery_with_invalid_integration_info(
|
|||||||
state = hass.states.get("binary_sensor.beer")
|
state = hass.states.get("binary_sensor.beer")
|
||||||
|
|
||||||
assert state is None
|
assert state is None
|
||||||
assert (
|
assert "Unable to parse origin information from discovery message" in caplog.text
|
||||||
"Unable to parse origin information from discovery message, got" in caplog.text
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_discover_fan(
|
async def test_discover_fan(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user